|
2 | 2 | #define UNION_FIND
|
3 | 3 |
|
4 | 4 | #include <vector>
|
| 5 | + |
5 | 6 | using namespace std;
|
6 | 7 |
|
7 |
| -class union_find |
8 |
| -{ |
| 8 | +class union_find { |
9 | 9 | private:
|
10 | 10 | vector<int> parent;
|
11 | 11 | vector<int> rank;
|
| 12 | + vector<int> cnt; |
12 | 13 | public:
|
13 |
| - union_find() |
14 |
| - { |
| 14 | + union_find() { |
15 | 15 | parent.clear();
|
16 | 16 | rank.clear();
|
| 17 | + cnt.clear(); |
17 | 18 | }
|
18 |
| - |
19 |
| - union_find(int n) |
20 |
| - { |
21 |
| - parent=vector<int>(n); |
22 |
| - rank=vector<int>(n,0); |
| 19 | + |
| 20 | + explicit union_find(int n) { |
| 21 | + parent = vector<int>(n); |
| 22 | + rank = vector<int>(n, 0); |
| 23 | + cnt = vector<int>(n, 0); |
23 | 24 | init(n);
|
24 | 25 | }
|
25 | 26 |
|
26 |
| - void init(int n) |
27 |
| - { |
28 |
| - for (int i=0;i<n;++i) |
29 |
| - { |
30 |
| - parent.at(i)=i; |
31 |
| - rank.at(i)=0; |
| 27 | + void init(int n) { |
| 28 | + for (int i = 0; i < n; ++i) { |
| 29 | + parent.at(i) = i; |
| 30 | + rank.at(i) = 0; |
| 31 | + cnt.at(i) = 1; |
32 | 32 | }
|
33 | 33 | }
|
34 | 34 |
|
35 |
| - int find_root(int i) |
36 |
| - { |
37 |
| - if (parent.at(i)!=i) |
38 |
| - { |
39 |
| - int root=find_root(parent.at(i)); |
40 |
| - parent.at(i)=root; |
| 35 | + int find_root(int i) { |
| 36 | + if (parent.at(i) != i) { |
| 37 | + int root = find_root(parent.at(i)); |
| 38 | + cnt.at(i) = cnt.at(root); |
| 39 | + parent.at(i) = root; |
41 | 40 | return root;
|
42 | 41 | }
|
43 | 42 | return i;
|
44 | 43 | }
|
45 | 44 |
|
46 |
| - bool union_vertices(int a,int b) |
47 |
| - { |
48 |
| - int a_root=find_root(a); |
49 |
| - int b_root=find_root(b); |
50 |
| - if (a_root==b_root) |
51 |
| - return false; |
52 |
| - if (rank.at(a_root)<rank.at(b_root)) |
53 |
| - parent.at(a_root)=b_root; |
54 |
| - else |
55 |
| - { |
56 |
| - parent.at(b_root)=a_root; |
57 |
| - if (rank.at(a_root)==rank.at(b_root)) |
| 45 | + int union_vertices(int a, int b) { |
| 46 | + int a_root = find_root(a); |
| 47 | + int b_root = find_root(b); |
| 48 | + if (a_root == b_root) |
| 49 | + return cnt.at(a_root); |
| 50 | + if (rank.at(a_root) < rank.at(b_root)) { |
| 51 | + parent.at(a_root) = b_root; |
| 52 | + cnt.at(b_root) += cnt.at(a_root); |
| 53 | + return cnt.at(b_root); |
| 54 | + } else { |
| 55 | + parent.at(b_root) = a_root; |
| 56 | + cnt.at(a_root) += cnt.at(b_root); |
| 57 | + if (rank.at(a_root) == rank.at(b_root)) |
58 | 58 | ++rank.at(a_root);
|
| 59 | + return cnt.at(a_root); |
59 | 60 | }
|
60 |
| - return true; |
| 61 | + return -1; |
| 62 | + } |
| 63 | + |
| 64 | + bool is_connected(int a, int b) { |
| 65 | + return find_root(a) == find_root(b); |
| 66 | + } |
| 67 | + |
| 68 | + int get_cnt(int i) { |
| 69 | + int root = find_root(i); |
| 70 | + return cnt.at(root); |
61 | 71 | }
|
62 | 72 | };
|
63 | 73 |
|
|
0 commit comments