From 04af3744d528aeefcd61be2eb72ccb8998aac01c Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Mon, 18 Jan 2021 15:51:25 +0900 Subject: [PATCH] #95: fix convolution --- .gitmodules | 3 +++ atcoder/convolution.hpp | 44 +++++++++++++++++++++++++--------- test/benchmark/CMakeLists.txt | 23 ++++++++++++++++++ test/benchmark/benchmark | 1 + test/benchmark/convolution.cpp | 35 +++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 11 deletions(-) create mode 100644 test/benchmark/CMakeLists.txt create mode 160000 test/benchmark/benchmark create mode 100644 test/benchmark/convolution.cpp diff --git a/.gitmodules b/.gitmodules index afa2247..a83a0aa 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "test/unittest/googletest"] path = test/unittest/googletest url = https://github.com/google/googletest +[submodule "test/benchmark/benchmark"] + path = test/benchmark/benchmark + url = https://github.com/google/benchmark diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index 166e356..9bb5682 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -101,25 +101,29 @@ void butterfly_inv(std::vector& a) { } } -} // namespace internal - template * = nullptr> -std::vector convolution(std::vector a, std::vector b) { +std::vector convolution_naive(const std::vector& a, const std::vector& b) { int n = int(a.size()), m = int(b.size()); - if (!n || !m) return {}; - if (std::min(n, m) <= 60) { - if (n < m) { - std::swap(n, m); - std::swap(a, b); + std::vector ans(n + m - 1); + if (n < m) { + for (int j = 0; j < m; j++) { + for (int i = 0; i < n; i++) { + ans[i + j] += a[i] * b[j]; + } } - std::vector ans(n + m - 1); + } else { for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { ans[i + j] += a[i] * b[j]; } } - return ans; } + return ans; +} + +template * = nullptr> +std::vector convolution_fft(std::vector a, std::vector b) { + int n = int(a.size()), m = int(b.size()); int z = 1 << internal::ceil_pow2(n + m - 1); a.resize(z); internal::butterfly(a); @@ -132,7 +136,25 @@ std::vector convolution(std::vector a, std::vector b) { a.resize(n + m - 1); mint iz = mint(z).inv(); for (int i = 0; i < n + m - 1; i++) a[i] *= iz; - return a; + return std::move(a); +} + +} // namespace internal + +template * = nullptr> +std::vector convolution(std::vector&& a, std::vector&& b) { + int n = int(a.size()), m = int(b.size()); + if (!n || !m) return {}; + if (std::min(n, m) <= 60) return convolution_naive(a, b); + return internal::convolution_fft(a, b); +} + +template * = nullptr> +std::vector convolution(const std::vector& a, const std::vector& b) { + int n = int(a.size()), m = int(b.size()); + if (!n || !m) return {}; + if (std::min(n, m) <= 60) return convolution_naive(a, b); + return internal::convolution_fft(a, b); } template + +#include "benchmark/benchmark.h" + +using namespace std; +using namespace atcoder; +using mint = modint998244353; + +void CONV_same_length(benchmark::State& state) { + vector a(state.range(0)), b(state.range(0)); + for (int i = 0; i < state.range(0); i++) { + a[i] = i + 1234; + b[i] = i + 5678; + } + for (auto _ : state) { + benchmark::DoNotOptimize(convolution(a, b)); + } +} +BENCHMARK(CONV_same_length)->RangeMultiplier(2)->Range(1, 1<<20); +BENCHMARK(CONV_same_length)->DenseRange(1, 100, 1); + +void CONV_long_empty(benchmark::State& state) { + vector a(state.range(0)), b; + for (int i = 0; i < state.range(0); i++) { + a[i] = i + 1234; + } + for (auto _ : state) { + benchmark::DoNotOptimize(convolution(a, b)); + benchmark::DoNotOptimize(convolution(b, a)); + } +} +BENCHMARK(CONV_long_empty)->RangeMultiplier(2)->Range(1, 1 << 20); + +BENCHMARK_MAIN();