From 193fe24a8903517c5fa2e2117f372da448be2ce4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 20 May 2025 22:51:37 -0700 Subject: [PATCH 1/6] Modify the test cases Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu | 2 +- tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu | 1 - tests/pytorch/distributed/run_numerics.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 0d49ae17fe..2a192e1027 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -376,7 +376,7 @@ std::vector> matrix_sizes = { {993, 512}, {768, 1024}, {65536, 128}, - {16384, 1632}, + {16384, 1504}, }; std::vector> block_sizes = { diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 15744fbeea..0c3d4a2fb6 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -143,7 +143,6 @@ void performTest(const size_t N, const size_t H) { std::vector> test_cases = {{64, 400}, {2048, 12288}, {768, 1024}, - {256, 65536}, {65536, 128}, {256, 256}}; diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 61dce2c5ec..c1edb74b17 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -185,7 +185,7 @@ def _get_tolerances(dtype): if dtype == torch.bfloat16: return {"rtol": 1.6e-2, "atol": 1e-5} if dtype == torch.float32: - return {"rtol": 1.3e-6, "atol": 4e-5} + return {"rtol": 1e-4, "atol": 1e-4} raise ValueError(f"Unsupported dtype ({dtype})") From 96ffb8a25f557c3b75ffcf598089b820aa29cab3 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 22 May 2025 14:24:43 -0700 Subject: [PATCH 2/6] Make the tests reproducible on different machines Signed-off-by: Przemek Tredak --- tests/cpp/test_common.cu | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 96ff39eaad..ad43f64c35 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -694,6 +694,17 @@ std::pair getTolerances(const DType type) { template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { + // Check how many times the distribution will call the generator + int k = 0; + { + std::mt19937 gen1 = *gen, gen2 = *gen; + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float _ = dis(gen1); + while (gen2 != gen1) { + auto _ = gen2(); + ++k; + } + } #pragma omp parallel proc_bind(spread) { std::mt19937 gen_local = *gen; @@ -702,14 +713,14 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { const int chunk_size = (size + threads_num - 1) / threads_num; const int idx_min = chunk_size * thread_ID; const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); - gen_local.discard(idx_min); + gen_local.discard(idx_min * k); std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = idx_min; i < idx_max; ++i) { data[i] = static_cast(dis(gen_local)); } } - gen->discard(size); + gen->discard(size * k); } void fillUniform(Tensor *t) { From 544d21957e28ff2e91aaa1568fc35c679c869fc7 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 22 May 2025 16:49:55 -0700 Subject: [PATCH 3/6] Fixed the cache of the gamma_in_weight_dtype setting Signed-off-by: Przemek Tredak --- .../common/normalization/common.cpp | 25 +++++++++++-------- .../common/normalization/common.h | 7 ++++-- .../common/normalization/layernorm/ln_api.cpp | 11 ++++++-- .../normalization/rmsnorm/rmsnorm_api.cpp | 12 +++++++-- 4 files changed, 38 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 89affc081c..2daeccac9c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -39,8 +39,6 @@ Compute always in FP32 namespace transformer_engine { namespace normalization { -bool& use_zero_centered_gamma_in_weight_dtype(); - cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { return training ? cudnn_frontend::NormFwdPhase_t::TRAINING : cudnn_frontend::NormFwdPhase_t::INFERENCE; @@ -49,13 +47,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, - bool is_tuned, NVTEScalingMode mode, bool training) { - // TODO: Add scaling_mode to general_key is needed - uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | - (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | - (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | - (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | - (uint32_t(mode) << 19) | (uint32_t(training) << 22); + bool is_tuned, NVTEScalingMode mode, bool training, bool gamma_in_weight_dtype) { + static_assert(NVTE_INVALID_SCALING < 1024, + "This function assumes at most 10 bits used in the scaling mode."); + static_assert(kNVTENumTypes < 32, + "This function assumes at most 5 bits used in the NVTEDType"); + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | + (static_cast(ctype) << 10) | (static_cast(wtype) << 15) | + (uint64_t(NormType) << 20) | (uint64_t(NormStage)) << 22 | + (uint64_t(NormBackend) << 24) | (uint64_t(zero_centered_gamma) << 26) | + (uint64_t(mode) << 27) | (uint64_t(training) << 37) | + (uint64_t(gamma_in_weight_dtype) << 38); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -466,11 +468,11 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode, const bool training) { + const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, - hidden_size, zero_centered_gamma, is_tuned, mode, training); + hidden_size, zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { @@ -528,6 +530,7 @@ void nvte_enable_cudnn_norm_bwd(bool enable) { transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; } +// Only for testing, not thread-safe void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable) { NVTE_API_CALL(nvte_enable_zero_centered_gamma_in_weight_dtype); transformer_engine::normalization::_zero_centered_gamma_in_weight_dtype() = enable; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index d465bdd581..0ec16046e3 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -159,7 +159,7 @@ TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, - bool training = true); + bool training = true, bool gamma_in_weight_dtype = false); template class TeNormalizationRegistry { @@ -307,7 +307,8 @@ class NormalizationPlanRegistry { NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, - const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true, + const bool gamma_in_weight_dtype = false); private: NormalizationPlanRegistry() {} @@ -381,6 +382,8 @@ bool is_ptr_aligned(const Args*... ptrs) { bool use_cudnn_norm_fwd(); bool use_cudnn_norm_bwd(); +bool& use_zero_centered_gamma_in_weight_dtype(); + } // namespace normalization } // namespace transformer_engine diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 47b37b3482..185ac6d457 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -15,6 +15,7 @@ #include "../../common.h" #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -64,9 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool gamma_in_weight_dtype = false; if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, @@ -83,7 +86,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -150,9 +154,11 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te NVTE_Norm_Backend norm_backend; bool is_aligned = true; + bool gamma_in_weight_dtype = false; if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, @@ -165,7 +171,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, + NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 48cf1d819b..842f72065f 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -13,6 +13,7 @@ #include "../../common.h" #include "../common.h" #include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -53,9 +54,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + bool gamma_in_weight_dtype = false; if (cudnn_backend) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); @@ -68,7 +71,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens z->data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); @@ -126,9 +130,11 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const NVTE_Norm_Backend norm_backend; bool is_aligned = true; + bool gamma_in_weight_dtype = false; if (use_cudnn_norm_bwd()) { // TODO: add check for GPU ARCH norm_backend = NVTE_Norm_Backend::Cudnn; + gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype(); } else { norm_backend = NVTE_Norm_Backend::Te; is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, @@ -141,7 +147,9 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned); + multiprocessorCount, zero_centered_gamma, is_aligned, + NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); From 5cdb363f714880642b0d6f3650ae1134bb6e5cde Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 22 May 2025 16:50:14 -0700 Subject: [PATCH 4/6] Reinstate the tests Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu | 4 ++-- tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 2a192e1027..2b22942f84 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -375,8 +375,8 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65536, 128}, - {16384, 1504}, + {65504, 128}, + {16384, 1632}, }; std::vector> block_sizes = { diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 0c3d4a2fb6..15744fbeea 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -143,6 +143,7 @@ void performTest(const size_t N, const size_t H) { std::vector> test_cases = {{64, 400}, {2048, 12288}, {768, 1024}, + {256, 65536}, {65536, 128}, {256, 256}}; From 575379a00c2d896879f5b66f7d2a8b297f05ed81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 May 2025 23:51:20 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/normalization/common.cpp | 21 ++++++++++--------- .../common/normalization/layernorm/ln_api.cpp | 4 ++-- .../normalization/rmsnorm/rmsnorm_api.cpp | 3 +-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 2daeccac9c..ae89c7773c 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -47,17 +47,17 @@ cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, - bool is_tuned, NVTEScalingMode mode, bool training, bool gamma_in_weight_dtype) { + bool is_tuned, NVTEScalingMode mode, bool training, + bool gamma_in_weight_dtype) { static_assert(NVTE_INVALID_SCALING < 1024, "This function assumes at most 10 bits used in the scaling mode."); - static_assert(kNVTENumTypes < 32, - "This function assumes at most 5 bits used in the NVTEDType"); + static_assert(kNVTENumTypes < 32, "This function assumes at most 5 bits used in the NVTEDType"); uint64_t general_key = static_cast(itype) | (static_cast(otype) << 5) | - (static_cast(ctype) << 10) | (static_cast(wtype) << 15) | - (uint64_t(NormType) << 20) | (uint64_t(NormStage)) << 22 | - (uint64_t(NormBackend) << 24) | (uint64_t(zero_centered_gamma) << 26) | - (uint64_t(mode) << 27) | (uint64_t(training) << 37) | - (uint64_t(gamma_in_weight_dtype) << 38); + (static_cast(ctype) << 10) | + (static_cast(wtype) << 15) | (uint64_t(NormType) << 20) | + (uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) | + (uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) | + (uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38); return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); } @@ -471,8 +471,9 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( const NVTEScalingMode mode, const bool training, const bool gamma_in_weight_dtype) { const DType ctype = DType::kFloat32; bool is_tuned = is_aligned && (batch_size % 4 == 0); - auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, - hidden_size, zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype); + auto key = + get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, + zero_centered_gamma, is_tuned, mode, training, gamma_in_weight_dtype); auto it = normalizationPlanMap.find(key); if (it != normalizationPlanMap.end()) { diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 185ac6d457..0025745257 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -171,8 +171,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, - NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype); + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); if (workspace->data.shape.empty()) { workspace->data.shape = plan->getWorkspaceShape(); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 842f72065f..08be5b9d48 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -147,8 +147,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const gamma.data.dtype, // otype x.data.shape[0], // batch_size x.data.shape[1], // hidden_size - multiprocessorCount, zero_centered_gamma, is_aligned, - NVTE_DELAYED_TENSOR_SCALING, true, + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype); if (workspace->data.shape.empty()) { From b1b78e54795f6e286d569c9a07aad60147496249 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 23 May 2025 12:05:28 -0700 Subject: [PATCH 6/6] More verbose code and comments Signed-off-by: Przemek Tredak --- tests/cpp/test_common.cu | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ad43f64c35..4c78ebedb5 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -694,17 +694,19 @@ std::pair getTolerances(const DType type) { template void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { - // Check how many times the distribution will call the generator - int k = 0; + // Check how many RNG calls are required to generate one uniform random value + int rng_calls_per_val = 0; { std::mt19937 gen1 = *gen, gen2 = *gen; std::uniform_real_distribution<> dis(-2.0, 1.0); const float _ = dis(gen1); while (gen2 != gen1) { auto _ = gen2(); - ++k; + ++rng_calls_per_val; } } + + // Generate uniform random values in parallel #pragma omp parallel proc_bind(spread) { std::mt19937 gen_local = *gen; @@ -713,14 +715,14 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { const int chunk_size = (size + threads_num - 1) / threads_num; const int idx_min = chunk_size * thread_ID; const int idx_max = std::min(chunk_size * (thread_ID + 1), static_cast(size)); - gen_local.discard(idx_min * k); + gen_local.discard(idx_min * rng_calls_per_val); std::uniform_real_distribution<> dis(-2.0, 1.0); for (int i = idx_min; i < idx_max; ++i) { data[i] = static_cast(dis(gen_local)); } } - gen->discard(size * k); + gen->discard(size * rng_calls_per_val); } void fillUniform(Tensor *t) {