From a2c3b52701345b76340f13e13a3c695137831cea Mon Sep 17 00:00:00 2001 From: Wenjing Zhu Date: Tue, 13 Sep 2022 01:42:47 -0700 Subject: [PATCH] add sinusoidal init method --- HugeCTR/embedding_storage/common.hpp | 15 +++ .../ragged_static_embedding.cu | 81 +++++++++--- HugeCTR/include/common.hpp | 4 +- HugeCTR/include/data_simulator.hpp | 7 + .../include/pybind/embedding_collection.hpp | 1 + HugeCTR/src/data_simulator.cu | 21 +++ .../test_embedding_collection.cpp | 122 ++++++++++++------ .../test_embedding_table.cpp | 1 + 8 files changed, 193 insertions(+), 59 deletions(-) diff --git a/HugeCTR/embedding_storage/common.hpp b/HugeCTR/embedding_storage/common.hpp index 91b76eb7a2..887aec6a22 100644 --- a/HugeCTR/embedding_storage/common.hpp +++ b/HugeCTR/embedding_storage/common.hpp @@ -34,6 +34,20 @@ using core::Tensor; using core::TensorList; using core::TensorScalarType; +struct UniformParams { + float up_bound; +}; +struct SinusoidalParams { + int ev_size; + int max_sequence_len; +}; + +struct EmbeddingTableInitParams { + HugeCTR::Initializer_t initializer_type; + UniformParams uniform_params; + SinusoidalParams sinus_params; +}; + struct EmbeddingTableParam { int table_id; int max_vocabulary_size; // -1 means dynamic @@ -42,5 +56,6 @@ struct EmbeddingTableParam { int64_t max_key; HugeCTR::OptParams opt_param; + EmbeddingTableInitParams init_param; }; } // namespace embedding diff --git a/HugeCTR/embedding_storage/ragged_static_embedding.cu b/HugeCTR/embedding_storage/ragged_static_embedding.cu index c4dacd67c2..c5488fb2c4 100644 --- a/HugeCTR/embedding_storage/ragged_static_embedding.cu +++ b/HugeCTR/embedding_storage/ragged_static_embedding.cu @@ -172,24 +172,71 @@ RaggedStaticEmbeddingTable::RaggedStaticEmbeddingTable( emb_table_ev_offset_.copy_from(cpu_emb_table_ev_offset); local_ev_size_list_.copy_from(cpu_local_ev_size_list); - auto uniform_init_table = [&](const curandGenerator_t &generator) { - const size_t num_tables = cpu_local_id_space_list.size(); - for (size_t embedding = 0; embedding < num_tables; embedding++) { - index_t num_keys = cpu_id_space_offset[embedding + 1] - cpu_id_space_offset[embedding]; - float up_bound = sqrt(1.f / num_keys); - size_t offset = cpu_emb_table_ev_offset[embedding]; - size_t num_elements = - cpu_emb_table_ev_offset[embedding + 1] - cpu_emb_table_ev_offset[embedding]; - - HugeCTR::UniformGenerator::fill(emb_table_.get() + offset, num_elements, -up_bound, - up_bound, gpu_resource.get_sm_count(), generator, - gpu_resource.get_stream()); + for (size_t embedding = 0; embedding < cpu_local_id_space_list.size(); embedding++) { + int id_space = cpu_local_id_space_list[embedding]; + const auto &init_param = global_emb_table_param_list[id_space].init_param; + if (init_param.initializer_type == HugeCTR::Initializer_t::Default) { + auto default_init_table = [&](const curandGenerator_t &generator) { + index_t num_keys = cpu_id_space_offset[embedding + 1] - cpu_id_space_offset[embedding]; + float up_bound = sqrt(1.f / num_keys); + size_t offset = cpu_emb_table_ev_offset[embedding]; + size_t num_elements = + cpu_emb_table_ev_offset[embedding + 1] - cpu_emb_table_ev_offset[embedding]; + + HugeCTR::UniformGenerator::fill(emb_table_.get() + offset, num_elements, + -up_bound, up_bound, gpu_resource.get_sm_count(), + generator, gpu_resource.get_stream()); + }; + + // data parallel table should use same curand seed across all gpus + if (sharding_param.table_placement_strategy == TablePlacementStrategy::DataParallel) { + default_init_table(gpu_resource.get_replica_uniform_curand_generator()); + } else { + default_init_table(gpu_resource.get_replica_variant_curand_generator()); + } + } else if (init_param.initializer_type == HugeCTR::Initializer_t::Uniform) { + auto uniform_init_table = [&](const curandGenerator_t &generator) { + float up_bound = init_param.uniform_params.up_bound; + size_t offset = cpu_emb_table_ev_offset[embedding]; + size_t num_elements = + cpu_emb_table_ev_offset[embedding + 1] - cpu_emb_table_ev_offset[embedding]; + + HugeCTR::UniformGenerator::fill(emb_table_.get() + offset, num_elements, + -up_bound, up_bound, gpu_resource.get_sm_count(), + generator, gpu_resource.get_stream()); + }; + + // data parallel table should use same curand seed across all gpus + if (sharding_param.table_placement_strategy == TablePlacementStrategy::DataParallel) { + uniform_init_table(gpu_resource.get_replica_uniform_curand_generator()); + } else { + uniform_init_table(gpu_resource.get_replica_variant_curand_generator()); + } + } else if (init_param.initializer_type == HugeCTR::Initializer_t::Sinusoidal) { + auto sinusoidal_init_table = [&] { + int max_sequence_len = init_param.sinus_params.max_sequence_len; + int ev_size = init_param.sinus_params.ev_size; + size_t offset = cpu_emb_table_ev_offset[embedding]; + size_t num_elements = + cpu_emb_table_ev_offset[embedding + 1] - cpu_emb_table_ev_offset[embedding]; + + HCTR_CHECK_HINT(max_sequence_len * ev_size == static_cast(num_elements), + "max_sequent_len * ev_size %d should equal to num_elements %d", + max_sequence_len * ev_size, static_cast(num_elements)); + HugeCTR::SinusoidalGenerator::fill( + emb_table_.get() + offset, num_elements, ev_size, max_sequence_len, + gpu_resource.get_sm_count(), gpu_resource.get_stream()); + }; + + // data parallel table should use same curand seed across all gpus + if (sharding_param.table_placement_strategy == TablePlacementStrategy::DataParallel) { + sinusoidal_init_table(); + } else { + HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall, "initializer not implemented"); + } + } else { + HCTR_OWN_THROW(HugeCTR::Error_t::IllegalCall, "initializer not implemented"); } - }; - if (sharding_param.table_placement_strategy == TablePlacementStrategy::DataParallel) { - uniform_init_table(gpu_resource.get_replica_uniform_curand_generator()); - } else { - uniform_init_table(gpu_resource.get_replica_variant_curand_generator()); } }); }); diff --git a/HugeCTR/include/common.hpp b/HugeCTR/include/common.hpp index 097695e31e..5e51b33401 100644 --- a/HugeCTR/include/common.hpp +++ b/HugeCTR/include/common.hpp @@ -170,7 +170,7 @@ enum class Embedding_t { None }; -enum class Initializer_t { Default, Uniform, XavierNorm, XavierUniform, Zero }; +enum class Initializer_t { Default, Uniform, XavierNorm, XavierUniform, Sinusoidal, Zero }; enum class TrainState_t { Init, @@ -316,4 +316,4 @@ struct DenseLayerSwitchs { DenseLayerSwitchs(bool fuse_wb_ = false) : fuse_wb(fuse_wb_) {} }; -} // namespace HugeCTR \ No newline at end of file +} // namespace HugeCTR diff --git a/HugeCTR/include/data_simulator.hpp b/HugeCTR/include/data_simulator.hpp index 6a4a2dfe6b..35a42cf8e0 100644 --- a/HugeCTR/include/data_simulator.hpp +++ b/HugeCTR/include/data_simulator.hpp @@ -30,6 +30,13 @@ class UniformGenerator { const curandGenerator_t& generator, const cudaStream_t& stream); }; +class SinusoidalGenerator { + public: + template + static void fill(T* ptr, size_t num_elements, int ev_size, int max_sequence_len, size_t sm_count, + const cudaStream_t& stream); +}; + class HostUniformGenerator { public: template diff --git a/HugeCTR/include/pybind/embedding_collection.hpp b/HugeCTR/include/pybind/embedding_collection.hpp index 5d6fced246..4550428ee6 100644 --- a/HugeCTR/include/pybind/embedding_collection.hpp +++ b/HugeCTR/include/pybind/embedding_collection.hpp @@ -39,6 +39,7 @@ class EmbeddingTableConfig { } else { param_.opt_param.optimizer = Optimizer_t::NOT_INITIALIZED; } + param_.init_param.initializer_type = HugeCTR::Initializer_t::Default; } }; diff --git a/HugeCTR/src/data_simulator.cu b/HugeCTR/src/data_simulator.cu index a221393d3d..4b75c1d11c 100644 --- a/HugeCTR/src/data_simulator.cu +++ b/HugeCTR/src/data_simulator.cu @@ -35,6 +35,27 @@ void UniformGenerator::fill(float* ptr, size_t num_elements, float a, flo transform_array<<>>(ptr, ptr, num_elements, op); } +template +__global__ void sinusoidal_kernel(T* output, int ev_size, int max_sequence_len) { + int row = blockIdx.x; + int col = threadIdx.x; + int offset = row * ev_size + col; + float log_result = __logf(10000) / (ev_size); + float exp_result = __expf(((col >> 1) << 1) * -1 * log_result); + + if (col < ev_size) { + output[offset] = (col % 2) ? (T)__cosf(exp_result * row) : (T)__sinf(exp_result * row); + } +} + +template <> +void SinusoidalGenerator::fill(float* ptr, size_t num_elements, int ev_size, + int max_sequence_len, size_t sm_count, + const cudaStream_t& stream) { + sinusoidal_kernel<<>>(ptr, ev_size, + max_sequence_len); +} + template <> void UniformGenerator::fill(Tensor2& tensor, float a, float b, size_t sm_count, const curandGenerator_t& generator, const cudaStream_t& stream) { diff --git a/test/utest/embedding_collection/test_embedding_collection.cpp b/test/utest/embedding_collection/test_embedding_collection.cpp index 56f5695e0a..f7e216dc09 100644 --- a/test/utest/embedding_collection/test_embedding_collection.cpp +++ b/test/utest/embedding_collection/test_embedding_collection.cpp @@ -23,16 +23,16 @@ #include "HugeCTR/core/hctr_impl/hctr_backend.hpp" #include "HugeCTR/embedding/embedding.hpp" #include "HugeCTR/embedding/embedding_planner.hpp" +#include "HugeCTR/embedding_storage/common.hpp" #include "HugeCTR/include/resource_managers/resource_manager_ext.hpp" #include "embedding_collection_cpu.hpp" #include "embedding_collection_utils.hpp" using namespace embedding; -std::vector get_table_param_list(int num_table, - const std::vector &table_ev_size_list, - const std::vector &table_min_key_list, - const std::vector &table_max_key_list, - core::DataType emb_type) { +std::vector get_table_param_list( + int num_table, const std::vector &table_ev_size_list, + const std::vector &table_min_key_list, const std::vector &table_max_key_list, + const std::vector init_param_list, core::DataType emb_type) { std::vector table_param_list; for (int id = 0; id < num_table; ++id) { EmbeddingTableParam table_param; @@ -46,6 +46,11 @@ std::vector get_table_param_list(int num_table, opt_param.lr = 1e-1; opt_param.scaler = (emb_type == TensorScalarType::Float16) ? 1024 : 1; table_param.opt_param = opt_param; + if (id < (int)init_param_list.size()) { + table_param.init_param = init_param_list[id]; + } else { + table_param.init_param.initializer_type = HugeCTR::Initializer_t::Default; + } table_param_list.push_back(std::move(table_param)); } return table_param_list; @@ -59,6 +64,7 @@ void embedding_collection_e2e(const std::vector device_list, const int &bat const std::vector &combiner_list, const std::vector table_min_key_list, const std::vector table_max_key_list, + const std::vector init_param_list, const std::string &plan_file) { ASSERT_TRUE(static_cast(num_table) == table_ev_size_list.size()); ASSERT_TRUE(static_cast(num_table) == table_min_key_list.size()); @@ -84,8 +90,9 @@ void embedding_collection_e2e(const std::vector device_list, const int &bat ebc_param.index_type = HugeCTR::TensorScalarTypeFunc::get_type(); ebc_param.offset_type = HugeCTR::TensorScalarTypeFunc::get_type(); ebc_param.emb_type = HugeCTR::TensorScalarTypeFunc::get_type(); - auto table_param_list = get_table_param_list(num_table, table_ev_size_list, table_min_key_list, - table_max_key_list, ebc_param.emb_type); + auto table_param_list = + get_table_param_list(num_table, table_ev_size_list, table_min_key_list, table_max_key_list, + init_param_list, ebc_param.emb_type); auto resource_manager = HugeCTR::ResourceManagerExt::create({device_list}, 0); int num_gpus = static_cast(device_list.size()); @@ -1013,75 +1020,91 @@ const std::vector hotness_list = {8, 20, 10, 5, 8}; const std::vector combiner_list = {Combiner::Sum, Combiner::Sum, Combiner::Sum, Combiner::Sum, Combiner::Sum}; const std::vector table_min_key_list = {0, 0, 0, 0, 0}; -const std::vector table_max_key_list = {100, 100, 100, 100, 100}; +const std::vector table_max_key_list = {100, 100, 500, 1000, 2000}; +EmbeddingTableInitParams init_param_default = {HugeCTR::Initializer_t::Default, UniformParams{0.0f}, + SinusoidalParams{0, 0}}; +EmbeddingTableInitParams init_param_uniform = {HugeCTR::Initializer_t::Uniform, UniformParams{1.0f}, + SinusoidalParams{0, 0}}; +EmbeddingTableInitParams init_param_sinus1 = {HugeCTR::Initializer_t::Sinusoidal, + UniformParams{0.0f}, SinusoidalParams{64, 500}}; +EmbeddingTableInitParams init_param_sinus2 = {HugeCTR::Initializer_t::Sinusoidal, + UniformParams{0.0f}, SinusoidalParams{16, 1000}}; +EmbeddingTableInitParams init_param_sinus3 = {HugeCTR::Initializer_t::Sinusoidal, + UniformParams{0.0f}, SinusoidalParams{8, 2000}}; +const std::vector init_param_list = { + init_param_default, init_param_default, init_param_default, init_param_default, + init_param_default}; +const std::vector dp_init_param_list = { + init_param_default, init_param_uniform, init_param_sinus1, init_param_sinus2, + init_param_sinus3}; TEST(test_embedding_collection, plan_0) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_i64) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_i64_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_1) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, dp_init_param_list, "/workdir/test/utest/embedding_collection/plan_1.json"); } TEST(test_embedding_collection, plan_1_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, dp_init_param_list, "/workdir/test/utest/embedding_collection/plan_1.json"); } TEST(test_embedding_collection, plan_2) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_2_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_3) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } TEST(test_embedding_collection, plan_3_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } @@ -1097,47 +1120,53 @@ const std::vector hotness_list = {8, 20, 10, 5, 8}; const std::vector combiner_list = {Combiner::Concat, Combiner::Average, Combiner::Concat, Combiner::Sum, Combiner::Sum}; const std::vector table_min_key_list = {0, 0, 0, 0, 0}; -const std::vector table_max_key_list = {1000, 1000, 1000, 1000, 1000}; +const std::vector table_max_key_list = {1000, 1000, 500, 1000, 2000}; + +EmbeddingTableInitParams init_param_default = {HugeCTR::Initializer_t::Default, UniformParams{0.0f}, + SinusoidalParams{0, 0}}; +const std::vector init_param_list = { + init_param_default, init_param_default, init_param_default, init_param_default, + init_param_default}; TEST(test_embedding_collection, plan_0_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_2_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_2_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_3_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } TEST(test_embedding_collection, plan_3_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } } // namespace concat_combiner @@ -1152,47 +1181,51 @@ const std::vector hotness_list = {8, 20, 10, 5, 8}; const std::vector combiner_list = {Combiner::Sum, Combiner::Average, Combiner::Sum, Combiner::Sum, Combiner::Sum}; const std::vector table_min_key_list = {0, 0, 0, 0}; -const std::vector table_max_key_list = {1000, 1000, 1000, 1000}; +const std::vector table_max_key_list = {1000, 1000, 500, 1000}; +EmbeddingTableInitParams init_param_default = {HugeCTR::Initializer_t::Default, UniformParams{0.0f}, + SinusoidalParams{0, 0}}; +const std::vector init_param_list = { + init_param_default, init_param_default, init_param_default, init_param_default}; TEST(test_embedding_collection, plan_0_share_id_space) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_share_id_space_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_2_share_id_space) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_2_share_id_space_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_3_share_id_space) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } TEST(test_embedding_collection, plan_3_share_id_space_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } } // namespace share_embedding_table @@ -1207,47 +1240,51 @@ const std::vector hotness_list = {8, 20, 10, 5, 8}; const std::vector combiner_list = {Combiner::Concat, Combiner::Average, Combiner::Sum, Combiner::Sum, Combiner::Sum}; const std::vector table_min_key_list = {0, 0, 0, 0}; -const std::vector table_max_key_list = {1000, 1000, 1000, 1000}; +const std::vector table_max_key_list = {1000, 1000, 500, 1000}; +EmbeddingTableInitParams init_param_default = {HugeCTR::Initializer_t::Default, UniformParams{0.0f}, + SinusoidalParams{0, 0}}; +const std::vector init_param_list = { + init_param_default, init_param_default, init_param_default, init_param_default}; TEST(test_embedding_collection, plan_0_share_id_space_and_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_0_share_id_space_and_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_0.json"); } TEST(test_embedding_collection, plan_2_share_id_space_and_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_2_share_id_space_and_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_2.json"); } TEST(test_embedding_collection, plan_3_share_id_space_and_concat) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } TEST(test_embedding_collection, plan_3_share_id_space_and_concat_half) { embedding_collection_e2e( gpus, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_3.json"); } } // namespace share_embedding_table_and_concat_combiner @@ -1277,11 +1314,16 @@ const std::vector table_min_key_list = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, const std::vector table_max_key_list = { 203931, 18598, 14092, 7012, 18977, 4, 6385, 1245, 49, 186213, 71328, 67288, 11, 2168, 7338, 61, 4, 932, 15, 204515, 141526, 199433, 60919, 9137, 71, 34}; +EmbeddingTableInitParams init_param_default = {HugeCTR::Initializer_t::Default, UniformParams{0.0f}, + SinusoidalParams{0, 0}}; +const std::vector init_param_list = { + init_param_default, init_param_default, init_param_default, init_param_default, + init_param_default}; TEST(test_embedding_collection, plan) { embedding_collection_e2e( gpus8, batch_size, num_table, table_ev_size_list, num_embedding, id_space_list, hotness_list, - combiner_list, table_min_key_list, table_max_key_list, + combiner_list, table_min_key_list, table_max_key_list, init_param_list, "/workdir/test/utest/embedding_collection/plan_criteo_8gpu.json"); } } // namespace criteo diff --git a/test/utest/embedding_collection/test_embedding_table.cpp b/test/utest/embedding_collection/test_embedding_table.cpp index 3366c234b8..f69b74a7b3 100644 --- a/test/utest/embedding_collection/test_embedding_table.cpp +++ b/test/utest/embedding_collection/test_embedding_table.cpp @@ -46,6 +46,7 @@ void test_ragged_static_embedding_table(int device_id) { param.max_key = max_vocabulary_size_list[id_space]; param.opt_param.optimizer = HugeCTR::Optimizer_t::SGD; param.opt_param.lr = 1e-1; + param.init_param.initializer_type = HugeCTR::Initializer_t::Default; param_list.push_back(param); }