From 7198637192ec61d775c82912cd8525f44be102b2 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Thu, 24 Jun 2021 03:27:48 +0900 Subject: [PATCH 1/7] add test for modint / convolution --- test/unittest/convolution_test.cpp | 31 ++++++++++++++++++++++++++++++ test/unittest/modint_test.cpp | 30 +++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/test/unittest/convolution_test.cpp b/test/unittest/convolution_test.cpp index cf8f55a..9005e0a 100644 --- a/test/unittest/convolution_test.cpp +++ b/test/unittest/convolution_test.cpp @@ -385,3 +385,34 @@ TEST(ConvolutionTest, Conv18433) { ASSERT_EQ(conv_naive(a, b), convolution(a, b)); } + +TEST(ConvolutionTest, Conv2) { + std::vector empty = {}; + ASSERT_EQ(empty, convolution<2>(empty, empty)); +} + +TEST(ConvolutionTest, Conv2147483647) { + const int MOD = 2147483647; + using mint = static_modint; + std::vector a(1), b(2); + for (int i = 0; i < 1; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 2; i++) { + b[i] = randint(0, MOD - 1); + } + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} + +TEST(ConvolutionTest, Conv2130706433) { + const int MOD = 2130706433; + using mint = static_modint; + std::vector a(1024), b(1024); + for (int i = 0; i < 1024; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 1024; i++) { + b[i] = randint(0, MOD - 1); + } + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} diff --git a/test/unittest/modint_test.cpp b/test/unittest/modint_test.cpp index 81f33cc..222c1bd 100644 --- a/test/unittest/modint_test.cpp +++ b/test/unittest/modint_test.cpp @@ -100,6 +100,29 @@ TEST(ModintTest, Mod1) { ASSERT_EQ(0, mint(true).val()); } +TEST(ModintTest, ModIntMax) { + modint::set_mod(INT32_MAX); + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 100; j++) { + ASSERT_EQ((modint(i) * modint(j)).val(), i * j); + } + } + ASSERT_EQ((modint(1234) + modint(5678)).val(), 1234 + 5678); + ASSERT_EQ((modint(1234) - modint(5678)).val(), INT32_MAX - 5678 + 1234); + ASSERT_EQ((modint(1234) * modint(5678)).val(), 1234 * 5678); + + using mint = static_modint; + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 100; j++) { + ASSERT_EQ((mint(i) * mint(j)).val(), i * j); + } + } + ASSERT_EQ((mint(1234) + mint(5678)).val(), 1234 + 5678); + ASSERT_EQ((mint(1234) - mint(5678)).val(), INT32_MAX - 5678 + 1234); + ASSERT_EQ((mint(1234) * mint(5678)).val(), 1234 * 5678); + ASSERT_EQ((mint(INT32_MAX) + mint(INT32_MAX)).val(), 0); +} + #ifndef _MSC_VER TEST(ModintTest, Int128) { @@ -158,6 +181,13 @@ TEST(ModintTest, Inv) { int x = modint(i).inv().val(); ASSERT_EQ(1, (ll(x) * i) % 1'000'000'008); } + + modint::set_mod(INT32_MAX); + for (int i = 1; i < 100000; i++) { + if (gcd(i, INT32_MAX) != 1) continue; + int x = modint(i).inv().val(); + ASSERT_EQ(1, (ll(x) * i) % INT32_MAX); + } } TEST(ModintTest, ConstUsage) { From 31f91384f88b42f675232116374007902029cb97 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Thu, 24 Jun 2021 03:28:09 +0900 Subject: [PATCH 2/7] optimize fft --- atcoder/convolution.hpp | 207 ++++++++++++++++++++++++++------------- atcoder/internal_bit.hpp | 8 ++ 2 files changed, 147 insertions(+), 68 deletions(-) diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index c01956c..9e84680 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -14,95 +14,165 @@ namespace atcoder { namespace internal { +template , + internal::is_static_modint_t* = nullptr> +struct fft_info { + static constexpr int rank2 = bsf_constexpr(mint::mod() - 1); + std::array root; // root[i]^(2^i) == 1 + std::array iroot; // root[i] * iroot[i] == 1 + + std::array rate2; + std::array irate2; + + std::array rate3; + std::array irate3; + + fft_info() { + root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2); + iroot[rank2] = root[rank2].inv(); + for (int i = rank2 - 1; i >= 0; i--) { + root[i] = root[i + 1] * root[i + 1]; + iroot[i] = iroot[i + 1] * iroot[i + 1]; + } + + { + mint prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 2; i++) { + rate2[i] = root[i + 2] * prod; + irate2[i] = iroot[i + 2] * iprod; + prod *= iroot[i + 2]; + iprod *= root[i + 2]; + } + } + { + mint prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 3; i++) { + rate3[i] = root[i + 3] * prod; + irate3[i] = iroot[i + 3] * iprod; + prod *= iroot[i + 3]; + iprod *= root[i + 3]; + } + } + } +}; + template * = nullptr> void butterfly(std::vector& a) { - static constexpr int g = internal::primitive_root; int n = int(a.size()); int h = internal::ceil_pow2(n); - static bool first = true; - static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i] - if (first) { - first = false; - mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 - int cnt2 = bsf(mint::mod() - 1); - mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); - for (int i = cnt2; i >= 2; i--) { - // e^(2^i) == 1 - es[i - 2] = e; - ies[i - 2] = ie; - e *= e; - ie *= ie; - } - mint now = 1; - for (int i = 0; i <= cnt2 - 2; i++) { - sum_e[i] = es[i] * now; - now *= ies[i]; - } - } - for (int ph = 1; ph <= h; ph++) { - int w = 1 << (ph - 1), p = 1 << (h - ph); - mint now = 1; - for (int s = 0; s < w; s++) { - int offset = s << (h - ph + 1); - for (int i = 0; i < p; i++) { - auto l = a[i + offset]; - auto r = a[i + offset + p] * now; - a[i + offset] = l + r; - a[i + offset + p] = l - r; + static const fft_info info; + + int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed + while (len < h) { + if (h - len == 1) { + int p = 1 << (h - len - 1); + mint rot = 1; + for (int s = 0; s < (1 << len); s++) { + int offset = s << (h - len); + for (int i = 0; i < p; i++) { + auto l = a[i + offset]; + auto r = a[i + offset + p] * rot; + a[i + offset] = l + r; + a[i + offset + p] = l - r; + } + rot *= info.rate2[bsf(~(unsigned int)(s))]; + } + len++; + } else { + // 4-base + int p = 1 << (h - len - 2); + mint rot = 1, imag = info.root[2]; + for (int s = 0; s < (1 << len); s++) { + mint rot2 = rot * rot; + mint rot3 = rot2 * rot; + int offset = s << (h - len); + for (int i = 0; i < p; i++) { + auto mod2 = 1ULL * mint::mod() * mint::mod(); + auto a0 = 1ULL * a[i + offset].val(); + auto a1 = 1ULL * a[i + offset + p].val() * rot.val(); + auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val(); + auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val(); + auto a1na3imag = + 1ULL * mint(a1 + mod2 - a3).val() * imag.val(); + auto na2 = mod2 - a2; + a[i + offset] = a0 + a2 + a1 + a3; + a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3)); + a[i + offset + 2 * p] = a0 + na2 + a1na3imag; + a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); + } + rot *= info.rate3[bsf(~(unsigned int)(s))]; } - now *= sum_e[bsf(~(unsigned int)(s))]; + len += 2; } } } template * = nullptr> void butterfly_inv(std::vector& a) { - static constexpr int g = internal::primitive_root; int n = int(a.size()); int h = internal::ceil_pow2(n); - static bool first = true; - static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i] - if (first) { - first = false; - mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 - int cnt2 = bsf(mint::mod() - 1); - mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); - for (int i = cnt2; i >= 2; i--) { - // e^(2^i) == 1 - es[i - 2] = e; - ies[i - 2] = ie; - e *= e; - ie *= ie; - } - mint now = 1; - for (int i = 0; i <= cnt2 - 2; i++) { - sum_ie[i] = ies[i] * now; - now *= es[i]; - } - } + static const fft_info info; + + int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed + while (len) { + if (len == 1) { + int p = 1 << (h - len); + mint irot = 1; + for (int s = 0; s < (1 << (len - 1)); s++) { + int offset = s << (h - len + 1); + for (int i = 0; i < p; i++) { + auto l = a[i + offset]; + auto r = a[i + offset + p]; + a[i + offset] = l + r; + a[i + offset + p] = + (unsigned long long)(mint::mod() + l.val() - r.val()) * + irot.val(); + ; + } + irot *= info.irate2[bsf(~(unsigned int)(s))]; + } + len--; + } else { + // 4-base + int p = 1 << (h - len); + mint irot = 1, iimag = info.iroot[2]; + for (int s = 0; s < (1 << (len - 2)); s++) { + mint irot2 = irot * irot; + mint irot3 = irot2 * irot; + int offset = s << (h - len + 2); + for (int i = 0; i < p; i++) { + auto a0 = 1ULL * a[i + offset + 0 * p].val(); + auto a1 = 1ULL * a[i + offset + 1 * p].val(); + auto a2 = 1ULL * a[i + offset + 2 * p].val(); + auto a3 = 1ULL * a[i + offset + 3 * p].val(); + + auto a2na3iimag = + 1ULL * + mint((mint::mod() + a2 - a3) * iimag.val()).val(); - for (int ph = h; ph >= 1; ph--) { - int w = 1 << (ph - 1), p = 1 << (h - ph); - mint inow = 1; - for (int s = 0; s < w; s++) { - int offset = s << (h - ph + 1); - for (int i = 0; i < p; i++) { - auto l = a[i + offset]; - auto r = a[i + offset + p]; - a[i + offset] = l + r; - a[i + offset + p] = - (unsigned long long)(mint::mod() + l.val() - r.val()) * - inow.val(); + a[i + offset] = a0 + a1 + a2 + a3; + a[i + offset + 1 * p] = + (a0 + (mint::mod() - a1) + a2na3iimag) * irot.val(); + a[i + offset + 2 * p] = + (a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) * + irot2.val(); + a[i + offset + 3 * p] = + (a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) * + irot3.val(); + } + irot *= info.irate3[bsf(~(unsigned int)(s))]; } - inow *= sum_ie[bsf(~(unsigned int)(s))]; + len -= 2; } } } template * = nullptr> -std::vector convolution_naive(const std::vector& a, const std::vector& b) { +std::vector convolution_naive(const std::vector& a, + const std::vector& b) { int n = int(a.size()), m = int(b.size()); std::vector ans(n + m - 1); if (n < m) { @@ -150,7 +220,8 @@ std::vector convolution(std::vector&& a, std::vector&& b) { } template * = nullptr> -std::vector convolution(const std::vector& a, const std::vector& b) { +std::vector convolution(const std::vector& a, + const std::vector& b) { int n = int(a.size()), m = int(b.size()); if (!n || !m) return {}; if (std::min(n, m) <= 60) return convolution_naive(a, b); diff --git a/atcoder/internal_bit.hpp b/atcoder/internal_bit.hpp index d219b0f..ada311a 100644 --- a/atcoder/internal_bit.hpp +++ b/atcoder/internal_bit.hpp @@ -17,6 +17,14 @@ int ceil_pow2(int n) { return x; } +// @param n `1 <= n` +// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` +constexpr int bsf_constexpr(unsigned int n) { + int x = 0; + while (!(n & (1 << x))) x++; + return x; +} + // @param n `1 <= n` // @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` int bsf(unsigned int n) { From fdf47809ac995dd213ffe0a8bad1bf5b06744dea Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Fri, 25 Jun 2021 00:36:56 +0900 Subject: [PATCH 3/7] add more test --- test/unittest/convolution_test.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/unittest/convolution_test.cpp b/test/unittest/convolution_test.cpp index 9005e0a..d1648f9 100644 --- a/test/unittest/convolution_test.cpp +++ b/test/unittest/convolution_test.cpp @@ -391,6 +391,19 @@ TEST(ConvolutionTest, Conv2) { ASSERT_EQ(empty, convolution<2>(empty, empty)); } +TEST(ConvolutionTest, Conv257) { + const int MOD = 257; + std::vector a(128), b(129); + for (int i = 0; i < 128; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 129; i++) { + b[i] = randint(0, MOD - 1); + } + + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} + TEST(ConvolutionTest, Conv2147483647) { const int MOD = 2147483647; using mint = static_modint; From adb894dcfc58cff20bcf11fa2e00128175bb2c41 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Fri, 25 Jun 2021 00:37:08 +0900 Subject: [PATCH 4/7] fix code for avoid out of range --- atcoder/convolution.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index 9e84680..9711424 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -22,11 +22,11 @@ struct fft_info { std::array root; // root[i]^(2^i) == 1 std::array iroot; // root[i] * iroot[i] == 1 - std::array rate2; - std::array irate2; + std::array rate2; + std::array irate2; - std::array rate3; - std::array irate3; + std::array rate3; + std::array irate3; fft_info() { root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2); From d6d3c8ca168b3655db34f4549d9abd27799e830d Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Mon, 19 Jul 2021 14:23:55 +0900 Subject: [PATCH 5/7] bit refactoring --- atcoder/convolution.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index 9711424..0362a5e 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -22,11 +22,11 @@ struct fft_info { std::array root; // root[i]^(2^i) == 1 std::array iroot; // root[i] * iroot[i] == 1 - std::array rate2; - std::array irate2; + std::array rate2; + std::array irate2; - std::array rate3; - std::array irate3; + std::array rate3; + std::array irate3; fft_info() { root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2); @@ -77,7 +77,7 @@ void butterfly(std::vector& a) { a[i + offset] = l + r; a[i + offset + p] = l - r; } - rot *= info.rate2[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) rot *= info.rate2[bsf(~(unsigned int)(s))]; } len++; } else { @@ -102,7 +102,7 @@ void butterfly(std::vector& a) { a[i + offset + 2 * p] = a0 + na2 + a1na3imag; a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); } - rot *= info.rate3[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) rot *= info.rate3[bsf(~(unsigned int)(s))]; } len += 2; } From 24b7295e59b5d56ea76e034e99d22dc807c35efc Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Mon, 19 Jul 2021 14:28:32 +0900 Subject: [PATCH 6/7] fix --- atcoder/convolution.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index 0362a5e..d7bc2ac 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -132,7 +132,7 @@ void butterfly_inv(std::vector& a) { irot.val(); ; } - irot *= info.irate2[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) irot *= info.irate2[bsf(~(unsigned int)(s))]; } len--; } else { @@ -163,7 +163,7 @@ void butterfly_inv(std::vector& a) { (a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) * irot3.val(); } - irot *= info.irate3[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) irot *= info.irate3[bsf(~(unsigned int)(s))]; } len -= 2; } From f841a342442f99a8df045e2d49696b9b3e21ea03 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Mon, 19 Jul 2021 14:38:53 +0900 Subject: [PATCH 7/7] fix --- atcoder/convolution.hpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index d7bc2ac..ce3d272 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -77,7 +77,8 @@ void butterfly(std::vector& a) { a[i + offset] = l + r; a[i + offset + p] = l - r; } - if (s + 1 != (1 << len)) rot *= info.rate2[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) + rot *= info.rate2[bsf(~(unsigned int)(s))]; } len++; } else { @@ -102,7 +103,8 @@ void butterfly(std::vector& a) { a[i + offset + 2 * p] = a0 + na2 + a1na3imag; a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); } - if (s + 1 != (1 << len)) rot *= info.rate3[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << len)) + rot *= info.rate3[bsf(~(unsigned int)(s))]; } len += 2; } @@ -132,7 +134,8 @@ void butterfly_inv(std::vector& a) { irot.val(); ; } - if (s + 1 != (1 << len)) irot *= info.irate2[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << (len - 1))) + irot *= info.irate2[bsf(~(unsigned int)(s))]; } len--; } else { @@ -163,7 +166,8 @@ void butterfly_inv(std::vector& a) { (a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) * irot3.val(); } - if (s + 1 != (1 << len)) irot *= info.irate3[bsf(~(unsigned int)(s))]; + if (s + 1 != (1 << (len - 2))) + irot *= info.irate3[bsf(~(unsigned int)(s))]; } len -= 2; }