-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathhadamard.hpp
59 lines (54 loc) · 2.03 KB
/
hadamard.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#pragma once
#include <cassert>
#include <utility>
#include <vector>
// Fast Walsh-Hadamard transform and its abstraction
// Tutorials: <https://codeforces.com/blog/entry/71899>
// <https://csacademy.com/blog/fast-fourier-transform-and-variations-of-it>
// Generic in-place butterfly transform over a sequence whose length is a power of two.
// For every level w = 1, 2, 4, ..., n / 2 and every index pair (p, p + w) whose positions
// differ only in the bit with value w, calls f(seq[p], seq[p + w]).
// f receives both elements by reference and must update them in place; the XOR / AND / OR
// convolutions in this file supply their butterflies through it.
// Precondition: seq.size() is a power of two (asserted).
template <typename T, typename F> void abstract_fwht(std::vector<T> &seq, F f) {
    const int n = static_cast<int>(seq.size());
    // Portable power-of-two check (the GCC/Clang-only __builtin_popcount is not available on MSVC).
    assert(n > 0 && (n & (n - 1)) == 0);
    for (int w = 1; w < n; w *= 2) {
        for (int i = 0; i < n; i += w * 2) {
            // at() keeps a defined failure (std::out_of_range) if the precondition is
            // violated in an NDEBUG build.
            for (int j = 0; j < w; j++) f(seq.at(i + j), seq.at(i + j + w));
        }
    }
}
// Shared driver for the bitwise convolutions (xorconv / andconv / orconv).
// Transforms x and y with the butterfly f, multiplies them pointwise, and applies the
// inverse butterfly finv to the product, returning the result.
// x and y must have the same power-of-two length (asserted).
// When x == y (squaring), only one forward transform is computed.
template <typename T, typename F1, typename F2>
std::vector<T> bitwise_conv(std::vector<T> x, std::vector<T> y, F1 f, F2 finv) {
    const int n = static_cast<int>(x.size());
    assert(n > 0 && (n & (n - 1)) == 0); // portable power-of-two check
    assert(x.size() == y.size());
    if (x == y) {
        // Squaring: transform once and square in place instead of copying x into y.
        abstract_fwht(x, f);
        for (int i = 0; i < n; i++) {
            const T v = x.at(i); // copy first so operator*= never sees an aliased operand
            x.at(i) *= v;
        }
    } else {
        abstract_fwht(x, f);
        abstract_fwht(y, f);
        // at() keeps a defined failure if the sizes mismatch in an NDEBUG build.
        for (int i = 0; i < n; i++) x.at(i) *= y.at(i);
    }
    abstract_fwht(x, finv);
    return x;
}
// Bitwise XOR convolution (based on the fast Walsh-Hadamard transform).
// ret[i] = \sum_j x[j] * y[i ^ j]
// x and y must have the same power-of-two length.
// If T is an integer type, require ||x||_1 * ||y||_1 * 2 < numeric_limits<T>::max()
// so that no intermediate value overflows.
template <typename T> std::vector<T> xorconv(std::vector<T> x, std::vector<T> y) {
    // Forward butterfly: (lo, hi) -> (lo + hi, lo - hi).
    auto f = [](T &lo, T &hi) {
        T c = lo + hi;
        hi = lo - hi;
        lo = c;
    };
    // Inverse butterfly: (lo, hi) -> ((lo + hi) / 2, (lo - hi) / 2).
    // The per-level butterflies act on different bits and commute, so after undoing any
    // set of levels the vector is still a partial transform of the integer result;
    // hence each "/ 2" is exact for integer T.
    // NOTE: when T is a ModInt, each division by 2 may be expensive (modular inverse);
    // reconsider this cost for such types.
    auto finv = [](T &lo, T &hi) {
        T c = lo + hi;
        hi = (lo - hi) / 2;
        lo = c / 2;
    };
    // x and y are by-value parameters that are not used again: move them to avoid two copies.
    return bitwise_conv(std::move(x), std::move(y), f, finv);
}
// Bitwise AND convolution.
// ret[i] = \sum_{(j & k) == i} x[j] * y[k]
// x and y must have the same power-of-two length.
// The forward butterfly lo += hi turns each entry into the sum over its bit-supersets;
// the inverse lo -= hi undoes it level by level.
template <typename T> std::vector<T> andconv(std::vector<T> x, std::vector<T> y) {
    // x and y are by-value parameters that are not used again: move them to avoid two copies.
    return bitwise_conv(
        std::move(x), std::move(y), [](T &lo, T &hi) { lo += hi; }, [](T &lo, T &hi) { lo -= hi; });
}
// Bitwise OR convolution.
// ret[i] = \sum_{(j | k) == i} x[j] * y[k]
// x and y must have the same power-of-two length.
// The forward butterfly hi += lo turns each entry into the sum over its bit-subsets;
// the inverse hi -= lo undoes it level by level.
template <typename T> std::vector<T> orconv(std::vector<T> x, std::vector<T> y) {
    // x and y are by-value parameters that are not used again: move them to avoid two copies.
    return bitwise_conv(
        std::move(x), std::move(y), [](T &lo, T &hi) { hi += lo; }, [](T &lo, T &hi) { hi -= lo; });
}