#include <bits/stdc++.h>
using namespace std;
// 64-bit integer alias; used below for the answer accumulator and the answer array.
typedef long long ll;
// Large "infinity" sentinel; not referenced in the visible part of this program.
const int inf = 0x3f3f3f3f;
// Capacity of the fixed-size arrays of this program (1-based indexing is used).
const int maxn = 50005;

// One offline range query for Mo's algorithm.
//   l, r  - inclusive query bounds
//   id    - position of the query in the input, used to store its answer
//   part  - index of the block that contains l
// Ordering: by block of l first; inside one block r ascends when the block
// index is odd and descends when it is even (odd-even optimisation), so the
// right pointer sweeps back and forth instead of jumping back each block.
struct node {
    int l;
    int r;
    int id;
    int part;

    bool operator < (const node &tmp) const {
        if (part != tmp.part) {
            return part < tmp.part;
        }
        const bool ascending = (part & 1) != 0;
        return ascending ? r < tmp.r : tmp.r < r;
    }
} p[maxn];
// Input sequence; main() fills a[1..n].
int a[maxn];
// ans[id]: answer stored for the query whose input index is id.
// res: running answer for the window currently maintained by main().
ll ans[maxn], res;
// Hook invoked by main() each time the window gains (add == 1) or loses
// (add == -1) position pos. Left empty in this template.
void update(int pos, int add) {
// problem-specific: fill in according to the task
}
// Driver of the Mo's-algorithm template.
// Input : n m k, then a[1..n], then m queries "l r".
// Output: one line per query, in input order, holding the value of `res`
//         after the window has been moved onto that query.
// NOTE(review): k is read but not used by this skeleton; presumably it is a
// problem parameter meant for update() - confirm against the task.
int main() {
    int n, m, k;
    if (scanf("%d%d%d", &n, &m, &k) != 3) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", a + i);

    // Block width ~ sqrt(n), clamped to 1 so the division below is always defined.
    const int block = max(1, static_cast<int>(sqrt(static_cast<double>(max(n, 1)))));

    for (int i = 1; i <= m; i++) {
        scanf("%d%d", &p[i].l, &p[i].r);
        p[i].id = i;
        p[i].part = (p[i].l - 1) / block + 1;
    }
    sort(p + 1, p + m + 1);

    // Start from the EMPTY window [1, 0]: every element that ends up inside the
    // window then passes through update(), so `res` never has to be seeded by
    // hand with the contribution of a[1].
    int l = 1, r = 0;
    res = 0; // answer of the empty window; adjust if the task defines it otherwise

    for (int i = 1; i <= m; i++) {
        // Grow first, shrink afterwards, so the window never becomes inverted.
        while (l > p[i].l) update(--l, 1);
        while (r < p[i].r) update(++r, 1);
        while (l < p[i].l) update(l++, -1);
        while (r > p[i].r) update(r--, -1);
        ans[p[i].id] = res;
    }
    for (int i = 1; i <= m; i++) {
        printf("%lld\n", ans[i]);
    }
    return 0;
}
- 询问为 $(l,r,t)$,其中 $t$ 为该询问之前的修改次数。
- 先按 $l$ 所在的块排序,再按 $r$ 所在的块排序,最后按 $t$ 排序。设 $N$ 为序列长度,$M$ 为总修改次数,推荐块长有 $\sqrt{N}$、$N^{\frac{2}{3}}$、$\sqrt[3]{N\times M}$。
- 块长为 $\sqrt[3]{N\times M}$ 时,理论复杂度最优,为 $O(N^{\frac{4}{3}}\times M^{\frac{1}{3}})$;另外 $(10^5)^{\frac{1}{3}}\approx 46$。
#include <bits/stdc++.h>
using namespace std;
// Convenience aliases; neither is referenced in the visible part of this program.
typedef pair<int, int> pi;
typedef long long ll;
// Capacity of a[] and ans[] below.
const int maxn = 140005;
// The remaining constants are not referenced in the visible part of this program.
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-7;
// One query of Mo's algorithm with modifications.
//   l, r   - inclusive query bounds
//   lp, rp - block indices of l and r
//   id     - position among the queries, used to store the answer
//   t      - number of updates that precede this query in the input
// Queries are ordered lexicographically by (lp, rp, t).
// Member order matters: run() builds a Query with aggregate syntax {l, r}.
struct Query {
    int l;
    int r;
    int lp;
    int rp;
    int id;
    int t;

    bool operator < (const Query &tmp) const {
        return std::tie(lp, rp, t) < std::tie(tmp.lp, tmp.rp, tmp.t);
    }
};
// A single point assignment in the update timeline.
// Member order matters: run() builds it with aggregate syntax {i, old, now}.
struct Update {
    int i;    // position being assigned
    int old;  // value stored at i just before this update
    int now;  // value stored at i by this update
};
// Current contents of the sequence (1-based); rewritten as updates are applied or undone.
int a[maxn];
// ans[id]: answer of the query with index id (0-based, in input order).
int ans[maxn];
// cnt[c]: how many times value c is currently counted in the window.
int cnt[1000006];
// Number of distinct values currently counted in the window.
int res;

// Count one more occurrence of value c; a first occurrence raises `res`.
void add_col(int c) {
    const bool first_occurrence = (cnt[c] == 0);
    ++cnt[c];
    if (first_occurrence) {
        ++res;
    }
}
// Remove one occurrence of value c; when the last one goes, `res` drops by one.
void del_col(int c) {
    --cnt[c];
    if (cnt[c] == 0) {
        --res;
    }
}
// Bring position i into the window: count the value currently stored there.
void add(int i) {
    const int value = a[i];
    add_col(value);
}
// Drop position i from the window: uncount the value currently stored there.
void del(int i) {
    const int value = a[i];
    del_col(value);
}
// Apply update `upd` (move forward in time) while the window is [l, r].
// If the touched position lies inside the window, the counted value is
// swapped from upd.old to upd.now; the array itself is always rewritten.
void add_upd(Update upd, int l, int r) {
    const int pos = upd.i;
    const bool in_window = !(pos < l || pos > r);
    if (in_window) {
        del_col(upd.old);
        add_col(upd.now);
    }
    // The array must be in the state that preceded this update.
    assert(a[pos] == upd.old);
    a[pos] = upd.now;
}
// Undo update `upd` (move backward in time) while the window is [l, r].
// Mirror image of add_upd: the counted value goes from upd.now back to upd.old.
void del_upd(Update upd, int l, int r) {
    const int pos = upd.i;
    const bool in_window = !(pos < l || pos > r);
    if (in_window) {
        del_col(upd.now);
        add_col(upd.old);
    }
    // The array must be in the state produced by this update.
    assert(a[pos] == upd.now);
    a[pos] = upd.old;
}
void run() {
int n, q;
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
}
vector<Query> query;
vector<Update> update;
while (q--) {
char op[5];
int x, y;
scanf("%s%d%d", op, &x, &y);
if (op[0] == 'Q') {
Query qi{x, y};
qi.id = query.size();
qi.t = update.size();
query.push_back(qi);
} else {
update.push_back({x, a[x], y});
a[x] = y;
}
}
int block = pow(n, 0.333) * max(1.0, pow(update.size(), 0.333));
// int block = ceil(exp((log(n) + max(1.0, log(update.size()))) / 3));
// int block = pow(n, 0.666);
// int block = sqrt(n);
for (Query &qi : query) {
qi.lp = qi.l / block;
qi.rp = qi.r / block;
}
sort(query.begin(), query.end());
int cur_t = update.size();
int l = 1, r = 1;
cnt[a[1]]++; res = 1;
for (Query qi : query) {
while (l > qi.l) add(--l);
while (r < qi.r) add(++r);
while (l < qi.l) del(l++);
while (r > qi.r) del(r--);
while (cur_t > qi.t) del_upd(update[--cur_t], l, r);
while (cur_t < qi.t) add_upd(update[cur_t++], l, r);
ans[qi.id] = res;
}
for (int i = 0; i < query.size(); i++) {
printf("%d\n", ans[i]);
}
}
// Entry point: runs a fixed number of test cases (one).
int main() {
    int T = 1;
    // scanf("%d", &T);   // enable to read the number of test cases from input
    for (; T != 0; --T) {
        run();
    }
    return 0;
}
用欧拉序将树转化成线性结构。
每次询问路径 $(x,y)$(不妨设 $L[x]\le L[y]$):
- 若 $lca(x,y)=x$,统计区间 $L[x]$ 到 $L[y]$ 中只出现过一次的点;
- 若 $lca(x,y)\neq x$,统计区间 $R[x]$ 到 $L[y]$ 中只出现过一次的点,外加 $lca(x,y)$ 本身。
#include <bits/stdc++.h>
using namespace std;
// Convenience aliases; neither is referenced in the visible part of this program.
typedef pair<int, int> pi;
typedef long long ll;
// Capacity of the per-vertex arrays (vertices are numbered from 1).
const int maxn = 40004;
// The remaining constants are not referenced in the visible part of this program.
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-7;
// col[v]: colour of vertex v; run() replaces it by its rank among the distinct colours.
int col[maxn];
// Adjacency lists of the tree (each edge is stored in both directions).
vector<int> G[maxn];
// d[v]: depth assigned by dfs(); fa[v]: parent assigned by dfs() (run() sets fa[1] = -1).
int d[maxn], fa[maxn];
// Euler tour: ola[k] is the vertex written at tour position k; L[v] / R[v] are the
// positions written when dfs() enters / leaves v; tot is the last position used.
int ola[maxn * 2], L[maxn], R[maxn], tot;
// Depth-first walk from u that builds the Euler tour.
// Writes u into ola[] once on entry (position L[u]) and once on exit
// (position R[u]), records depth in d[], and sets fa[] of every child.
void dfs(int u, int depth) {
    d[u] = depth;
    L[u] = ++tot;
    ola[L[u]] = u;
    for (size_t k = 0; k < G[u].size(); ++k) {
        const int v = G[u][k];
        if (v != fa[u]) {
            fa[v] = u;
            dfs(v, depth + 1);
        }
    }
    R[u] = ++tot;
    ola[R[u]] = u;
}
// st[v][j]: the 2^j-th ancestor of v, or -1 when it does not exist.
int st[maxn][20];

// Fill the binary-lifting table for vertices 1..n from fa[].
// Row 0 acts as a sentinel whose ancestors are all -1.
// Only levels j with 2^j <= n are written; higher levels keep their
// zero-initialised value.
void initST(int n) {
    st[0][0] = -1;
    for (int v = 1; v <= n; ++v) {
        st[v][0] = fa[v];
    }
    for (int lvl = 1; (1 << lvl) <= n; ++lvl) {
        for (int v = 0; v <= n; ++v) {
            const int half = st[v][lvl - 1];
            st[v][lvl] = (half == -1) ? -1 : st[half][lvl - 1];
        }
    }
}
// Lowest common ancestor of u and v using the binary-lifting table st[][].
int LCA(int u, int v) {
    if (d[u] < d[v]) swap(u, v);

    // Lift the deeper vertex: bit `bit` of the remaining depth gap decides
    // whether a 2^bit jump is taken.
    int bit = 0;
    while (d[u] != d[v]) {
        if (((d[u] - d[v]) >> bit) & 1) {
            u = st[u][bit];
        }
        ++bit;
    }
    if (u == v) return u;

    // Lift both vertices together while their ancestors still differ.
    for (int lvl = 19; lvl >= 0; --lvl) {
        const int up_u = st[u][lvl];
        const int up_v = st[v][lvl];
        if (up_u == -1 || up_v == -1 || up_u == up_v) continue;
        u = up_u;
        v = up_v;
    }
    return st[u][0];
}
// Block width over Euler-tour positions; must be set before any Query is constructed.
int block;

// A path query (u, v) mapped to a range of the Euler tour.
//   l, r - inclusive tour range to maintain
//   part - index of the block that contains l
//   id   - input position of the query
//   lca  - 0 when the LCA is the endpoint with the smaller entry position;
//          otherwise the LCA, which lies outside [l, r] and is handled separately
// Ordering: by block of l; inside a block r ascends for odd block indices
// and descends for even ones.
struct Query {
    int l;
    int r;
    int part;
    int id;
    int lca;

    Query() {}

    Query(int u, int v, int _id) {
        id = _id;
        if (L[u] > L[v]) swap(u, v);
        const int anc = LCA(u, v);
        const bool u_is_ancestor = (anc == u);
        l = u_is_ancestor ? L[u] : R[u];
        r = L[v];
        lca = u_is_ancestor ? 0 : anc;
        part = (l - 1) / block + 1;
    }

    bool operator < (const Query &tmp) const {
        if (part != tmp.part) {
            return part < tmp.part;
        }
        return (part & 1) ? r < tmp.r : tmp.r < r;
    }
} query[100005];
// cnt[v]: how many of v's tour positions lie in the window (0, 1 or 2).
// cnt_col[c]: how many vertices of colour c are currently counted.
int cnt[maxn], cnt_col[maxn];
// Number of distinct colours currently counted.
int cur;

// Count one more vertex of colour c; a first one raises `cur`.
void add_col(int c) {
    if (cnt_col[c]++ == 0) {
        ++cur;
    }
}
// Uncount one vertex of colour c; when none remain, `cur` drops by one.
void del_col(int c) {
    if (--cnt_col[c] == 0) {
        --cur;
    }
}
// Extend the window over tour position i.
// A vertex seen exactly once is counted; seeing it a second time cancels it.
void add(int i) {
    const int u = ola[i];
    switch (++cnt[u]) {
        case 1:
            add_col(col[u]);
            break;
        case 2:
            del_col(col[u]);
            break;
        default:
            break;
    }
}
// Shrink the window past tour position i.
// Dropping from two occurrences to one re-counts the vertex; dropping to zero uncounts it.
void del(int i) {
    const int u = ola[i];
    switch (--cnt[u]) {
        case 1:
            add_col(col[u]);
            break;
        case 0:
            del_col(col[u]);
            break;
        default:
            break;
    }
}
int ans[100005];
void run() {
int n, q;
scanf("%d%d", &n, &q);
set<int> s;
map<int, int> mp;
for (int i = 1; i <= n; i++) {
scanf("%d", col + i);
s.insert(col[i]);
}
int rak = 0;
for (int i : s) {
mp[i] = ++rak;
}
for (int i = 1; i <= n; i++) {
col[i] = mp[col[i]];
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
fa[1] = -1;
dfs(1, 0);
initST(n);
block = sqrt(n * 2);
for (int i = 1; i <= q; i++) {
int u, v;
scanf("%d%d", &u, &v);
query[i] = Query(u, v, i);
}
sort(query + 1, query + 1 + q);
int l = 1, r = 1;
cnt[ola[1]]++;
cnt_col[col[ola[1]]]++;
cur = 1;
for (int i = 1; i <= q; i++) {
while (l > query[i].l) add(--l);
while (r < query[i].r) add(++r);
while (l < query[i].l) del(l++);
while (r > query[i].r) del(r--);
int tmp = 0;
if (query[i].lca && cnt_col[col[query[i].lca]] == 0) {
tmp = 1;
}
ans[query[i].id] = cur + tmp;
}
for (int i = 1; i <= q; i++) {
printf("%d\n", ans[i]);
}
}
// Entry point: runs a fixed number of test cases (one).
int main() {
    int T = 1;
    // scanf("%d", &T);   // enable to read the number of test cases from input
    for (; T != 0; --T) {
        run();
    }
    return 0;
}