diff --git a/include/infinicore/context/context.hpp b/include/infinicore/context/context.hpp index 1612db830..e6db2ac55 100644 --- a/include/infinicore/context/context.hpp +++ b/include/infinicore/context/context.hpp @@ -11,7 +11,7 @@ namespace infinicore { namespace context { -void setDevice(Device device); +void setDevice(Device device, bool force_cpu = false); Device getDevice(); size_t getDeviceCount(Device::Type type); diff --git a/include/infinicore/device.hpp b/include/infinicore/device.hpp index 86d372066..0c2562391 100644 --- a/include/infinicore/device.hpp +++ b/include/infinicore/device.hpp @@ -39,6 +39,10 @@ class Device { bool operator!=(const Device &other) const; + inline static Device cpu() { + return Device(Type::CPU, 0); + } + private: Type type_; diff --git a/include/infinicore/nn/parameter.hpp b/include/infinicore/nn/parameter.hpp index 214fa85cb..9d87c8866 100644 --- a/include/infinicore/nn/parameter.hpp +++ b/include/infinicore/nn/parameter.hpp @@ -9,8 +9,19 @@ class Parameter : public Tensor { Parameter(const Shape &shape, const DataType &dtype, - const Device &device); + const Device &device, + Size tp_dim = 0, + Size tp_rank = 0, + Size tp_size = 1); void load_blob(const void *data); + + void load(const Tensor &tensor); + +protected: + // Tensor parallel configs; the defaults describe an unpartitioned parameter so a default-constructed Parameter never reads indeterminate values + Size tp_dim_ = 0; // dimension partitioned + Size tp_rank_ = 0; // rank of this partition among tp group + Size tp_size_ = 1; // total number of partitions }; } // namespace infinicore::nn diff --git a/python/infinicore/context.py b/python/infinicore/context.py index d74a839f2..1c92c4bd6 100644 --- a/python/infinicore/context.py +++ b/python/infinicore/context.py @@ -23,13 +23,13 @@ def get_device_count(device_type): return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type) -def set_device(device): +def set_device(device, force_cpu=False): """Set the current active device. 
Args: device: The device to set as active """ - _infinicore.set_device(device._underlying) + _infinicore.set_device(device._underlying, force_cpu) def sync_stream(): diff --git a/src/infinicore-test/memory_test.cc b/src/infinicore-test/memory_test.cc index e0903969c..2029e0a81 100644 --- a/src/infinicore-test/memory_test.cc +++ b/src/infinicore-test/memory_test.cc @@ -709,9 +709,6 @@ TestResult PerformanceTest::testMemoryCopyPerformance() { return false; } - // Initialize source data - std::memset(src_memory->data(), 0xAB, data_size); - auto start = std::chrono::high_resolution_clock::now(); // Perform memory copies diff --git a/src/infinicore-test/test_nn_module.cc b/src/infinicore-test/test_nn_module.cc index 7e9be04f6..acc55522e 100644 --- a/src/infinicore-test/test_nn_module.cc +++ b/src/infinicore-test/test_nn_module.cc @@ -3,6 +3,20 @@ namespace infinicore::test { +// Helper function to format shape for logging +inline std::string formatShape(const std::vector &shape) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) { + oss << ", "; + } + oss << shape[i]; + } + oss << "]"; + return oss.str(); +} + // Test 1: Basic module operations (creation, parameters, state_dict, load_state_dict) TestResult NNModuleTest::testBasicModuleCreation() { return measureTime("BasicModuleOperations", [this]() { @@ -115,6 +129,174 @@ TestResult NNModuleTest::testBasicModuleCreation() { }); } +TestResult NNModuleTest::testTensorParallelParameters() { + return measureTime("TensorParallelParameters", [this]() { + try { + spdlog::info("=========================================="); + spdlog::info("Testing Tensor Parallel Parameters"); + spdlog::info("=========================================="); + + auto device = infinicore::context::getDevice(); + + spdlog::info("Test Tensor Parallel Parameter"); + // Case 1: Partition along dimension 0 (row-wise partitioning) + infinicore::nn::Parameter param_dim0({8, 4}, infinicore::DataType::F32, 
device, 0, 0, 2); + if (param_dim0->shape() != std::vector({4, 4})) { + spdlog::error("TP dim0: Expected shape [4, 4], got [{}]", formatShape(param_dim0->shape())); + return false; + } + spdlog::info("✓ TP dim0 parameter created with correct partitioned shape"); + // Case 2: Partition along dimension 1 (column-wise partitioning) + infinicore::nn::Parameter param_dim1({8, 4}, infinicore::DataType::F32, device, 1, 0, 2); + if (param_dim1->shape() != std::vector({8, 2})) { + spdlog::error("TP dim1: Expected shape [8, 2], got [{}]", formatShape(param_dim1->shape())); + return false; + } + spdlog::info("✓ TP dim1 parameter created with correct partitioned shape"); + spdlog::info("✓ Parameter creation with tensor parallelism passed"); + + spdlog::info("Test Tensor Parallel Linear Module"); + auto w_data = std::vector(32 * 64); + auto b_data = std::vector(32); + for (size_t i = 0; i < 32; ++i) { + for (size_t j = 0; j < 64; ++j) { + w_data[i * 64 + j] = static_cast(j); + } + b_data[i] = static_cast(i); + } + { + spdlog::info("Test tp_size=4 tp_dim=0"); + Size tp_size = 4; + Size tp_dim = 0; + std::vector> tp_modules; + + for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) { + auto module = std::make_unique(64, 32, device, tp_dim, tp_rank, tp_size); + tp_modules.push_back(std::move(module)); + } + + // Verify each partition has correct shape + for (size_t i = 0; i < tp_modules.size(); ++i) { + const auto &weight = tp_modules[i]->get_weight(); + const auto &bias = tp_modules[i]->get_bias(); + + // Weight should be partitioned along output dimension (dim 0) + if (weight->shape() != std::vector({8, 64})) { // 32/4 = 8 + spdlog::error("TP rank {}: Weight shape mismatch. Expected [8, 64], got [{}]", + i, formatShape(weight->shape())); + return false; + } + + // Bias should be partitioned along output dimension + if (bias->shape() != std::vector({8})) { // 32/4 = 8 + spdlog::error("TP rank {}: Bias shape mismatch. 
Expected [8], got [{}]", + i, formatShape(bias->shape())); + return false; + } + + spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]", + i, formatShape(weight->shape()), formatShape(bias->shape())); + + tp_modules[i]->load_parameter_from_blob("weight", w_data.data()); + tp_modules[i]->load_parameter_from_blob("bias", b_data.data()); + + auto weight_loaded = infinicore::Tensor::from_blob( + w_data.data(), + {32, 64}, + infinicore::DataType::F32, + infinicore::Device::cpu()) + ->narrow({{0, i * 8, 8}}) + ->to(device); // Narrow to get the partition + auto bias_loaded = infinicore::Tensor::from_blob( + b_data.data(), + {32}, + infinicore::DataType::F32, + infinicore::Device::cpu()) + ->narrow({{0, i * 8, 8}}) + ->to(device); // Narrow to get the partition + + if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) { + spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i); + return false; + } + + if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) { + spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i); + return false; + } + } + } + + { + spdlog::info("Test tp_size=4 tp_dim=1"); + Size tp_size = 4; + Size tp_dim = 1; + std::vector> tp_modules; + + for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) { + auto module = std::make_unique(64, 32, device, tp_dim, tp_rank, tp_size); + tp_modules.push_back(std::move(module)); + } + + // Verify each partition has correct shape + for (size_t i = 0; i < tp_modules.size(); ++i) { + const auto &weight = tp_modules[i]->get_weight(); + const auto &bias = tp_modules[i]->get_bias(); + + // Weight should be partitioned along input dimension (dim 1) + if (weight->shape() != std::vector({32, 16})) { // 64/4 = 16 + spdlog::error("TP rank {}: Weight shape mismatch. 
Expected [32, 16], got [{}]", + i, formatShape(weight->shape())); + return false; + } + + // Bias is replicated on every rank when tp_dim=1, so it keeps its full shape + if (bias->shape() != std::vector({32})) { // Bias not partitioned when tp_dim=1 + spdlog::error("TP rank {}: Bias shape mismatch. Expected [32], got [{}]", + i, formatShape(bias->shape())); + return false; + } + + spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]", + i, formatShape(weight->shape()), formatShape(bias->shape())); + + tp_modules[i]->load_parameter_from_blob("weight", w_data.data()); + tp_modules[i]->load_parameter_from_blob("bias", b_data.data()); + + auto weight_loaded = infinicore::Tensor::from_blob( + w_data.data(), + {32, 64}, + infinicore::DataType::F32, + infinicore::Device::cpu()) + ->narrow({{1, i * 16, 16}}) + ->to(device); // Narrow to get the partition + auto bias_loaded = infinicore::Tensor::from_blob( + b_data.data(), + {32}, + infinicore::DataType::F32, + infinicore::Device::cpu()) + ->to(device); // Bias is not narrowed: compare against the full tensor + if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) { + spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i); + return false; + } + if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) { + spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i); + return false; + } + } + } + + spdlog::info("=== All Tensor Parallel Parameter Tests Passed ==="); + return true; + + } catch (const std::exception &e) { + spdlog::error("Exception in testTensorParallelParameters: {}", e.what()); + return false; + } + }); +} + // Test 2: Advanced load state dict functionality (hierarchical modules) TestResult NNModuleTest::testLoadStateDict() { return measureTime("AdvancedLoadStateDict", [this]() { @@ -384,6 +566,8 @@ TestResult NNModuleTest::testParameterLoading() { return false; } + MockLinearModule module_row_parallel(3, 2, infinicore::Device(), 0, 1, 2); + 
spdlog::info("Parameter loading test passed"); return true; } catch (const std::exception &e) { @@ -1708,16 +1892,17 @@ TestResult NNModuleTest::run() { << "InfiniCore nn::Module Test Suite\n" << "==============================================" << std::endl; - results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load - results.push_back(testLoadStateDict()); // Advanced: hierarchical modules - results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction - results.push_back(testParameterLoading()); // Blob loading - results.push_back(testModuleLinear()); // Linear module comprehensive test - results.push_back(testModuleEmbedding()); // Embedding module test - results.push_back(testModuleRMSNorm()); // RMSNorm module test - results.push_back(testModuleRoPE()); // RoPE module test - results.push_back(testDtypeAssertion()); // Dtype assertion test - results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test + results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load + results.push_back(testTensorParallelParameters()); // Tensor-parallel parameters + results.push_back(testLoadStateDict()); // Advanced: hierarchical modules + results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction + results.push_back(testParameterLoading()); // Blob loading + results.push_back(testModuleLinear()); // Linear module comprehensive test + results.push_back(testModuleEmbedding()); // Embedding module test + results.push_back(testModuleRMSNorm()); // RMSNorm module test + results.push_back(testModuleRoPE()); // RoPE module test + results.push_back(testDtypeAssertion()); // Dtype assertion test + results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test // Check if all tests passed bool all_passed = true; diff --git a/src/infinicore-test/test_nn_module.h b/src/infinicore-test/test_nn_module.h index 
d9ffcd867..3707f2c23 100644 --- a/src/infinicore-test/test_nn_module.h +++ b/src/infinicore-test/test_nn_module.h @@ -26,17 +26,25 @@ class MockLinearModule : public infinicore::nn::Module { INFINICORE_NN_PARAMETER(weight); INFINICORE_NN_PARAMETER(bias); - MockLinearModule(int input_size, int output_size, const infinicore::Device &device) - : input_size_(input_size), output_size_(output_size), device_(device) { + MockLinearModule(int input_size, int output_size, const infinicore::Device &device, + Size tp_dim = 0, Size tp_rank = 0, Size tp_size = 1) + : input_size_(input_size), output_size_(output_size), device_(device), + tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) { // Initialize parameters using macros INFINICORE_NN_PARAMETER_INIT(weight, ({static_cast(output_size), static_cast(input_size)}, infinicore::DataType::F32, - device)); + device, + tp_dim_, + tp_rank_, + tp_size_)); INFINICORE_NN_PARAMETER_INIT(bias, ({static_cast(output_size)}, infinicore::DataType::F32, - device)); + device, + 0, + tp_dim == 0 ? tp_rank_ : 0, + tp_dim == 0 ? 
tp_size_ : 1)); } // Simple forward pass (conceptual - would need actual matrix operations) @@ -68,6 +76,10 @@ class MockLinearModule : public infinicore::nn::Module { int input_size_; int output_size_; infinicore::Device device_; + + Size tp_dim_; + Size tp_rank_; + Size tp_size_; }; class NNModuleTest : public TestFramework { @@ -76,16 +88,17 @@ class NNModuleTest : public TestFramework { std::string getName() const override { return "NNModuleTest"; } private: - TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict - TestResult testLoadStateDict(); // Advanced: hierarchical modules - TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern - TestResult testParameterLoading(); // Test blob parameter loading - TestResult testModuleLinear(); // Comprehensive Linear module test - TestResult testModuleEmbedding(); // Embedding module test - TestResult testModuleRMSNorm(); // RMSNorm module test - TestResult testModuleRoPE(); // RoPE module test - TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters - TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation + TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict + TestResult testTensorParallelParameters(); // Module with tensor parallel parameters + TestResult testLoadStateDict(); // Advanced: hierarchical modules + TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern + TestResult testParameterLoading(); // Test blob parameter loading + TestResult testModuleLinear(); // Comprehensive Linear module test + TestResult testModuleEmbedding(); // Embedding module test + TestResult testModuleRMSNorm(); // RMSNorm module test + TestResult testModuleRoPE(); // RoPE module test + TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters + TestResult 
testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation }; } // namespace infinicore::test diff --git a/src/infinicore/context/context_impl.cc b/src/infinicore/context/context_impl.cc index 0c3d1b4c8..3f8c726e1 100644 --- a/src/infinicore/context/context_impl.cc +++ b/src/infinicore/context/context_impl.cc @@ -33,11 +33,15 @@ Runtime *ContextImpl::getCpuRuntime() { return runtime_table_[int(Device::Type::CPU)][0].get(); } -void ContextImpl::setDevice(Device device) { +void ContextImpl::setDevice(Device device, bool force_cpu) { if (device == getCurrentRuntime()->device()) { // Do nothing if the device is already set. return; } + if (device == Device(Device::Type::CPU, 0) && !force_cpu) { + // if not forced, no need to switch to CPU device runtime + return; + } if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) { // Lazy initialization of runtime if never set before. @@ -83,8 +87,8 @@ ContextImpl::ContextImpl() { namespace context { -void setDevice(Device device) { - ContextImpl::singleton().setDevice(device); +void setDevice(Device device, bool force_cpu) { + ContextImpl::singleton().setDevice(device, force_cpu); } Device getDevice() { diff --git a/src/infinicore/context/context_impl.hpp b/src/infinicore/context/context_impl.hpp index ea9fbae66..b9ae2b47f 100644 --- a/src/infinicore/context/context_impl.hpp +++ b/src/infinicore/context/context_impl.hpp @@ -21,7 +21,7 @@ class ContextImpl { Runtime *getCpuRuntime(); - void setDevice(Device); + void setDevice(Device, bool force_cpu = false); size_t getDeviceCount(Device::Type type); diff --git a/src/infinicore/nn/module.cc b/src/infinicore/nn/module.cc index 64252a538..711f39e53 100644 --- a/src/infinicore/nn/module.cc +++ b/src/infinicore/nn/module.cc @@ -1,4 +1,5 @@ #include "infinicore/nn/module.hpp" +#include #include namespace infinicore::nn { @@ -21,28 +22,28 @@ void Module::load_state_dict(const std::unordered_map &_sta // Look up the corresponding tensor 
in the input state dict using the full name auto it = _state_dict.find(param_full_name); if (it != _state_dict.end()) { - // Assert dtype matches - if (param->dtype() != it->second->dtype()) { - throw std::runtime_error( - "dtype mismatch for parameter '" + param_full_name + "': " - "expected " - + std::to_string(static_cast(param->dtype())) + ", got " + std::to_string(static_cast(it->second->dtype()))); - } - param->copy_from(it->second); + this->load_parameter(param_full_name, it->second); + } else { + spdlog::warn("Parameter '{}' not found in the provided state dict; keeping its current value.", param_full_name); } } } void Module::load_parameter(const std::string &name, const Tensor &param) { - auto existing_param = parameters_[name]; - // Assert dtype matches - if (existing_param->dtype() != param->dtype()) { - throw std::runtime_error( - "dtype mismatch for parameter '" + name + "': " - "expected " - + std::to_string(static_cast(existing_param->dtype())) + ", got " + std::to_string(static_cast(param->dtype()))); + auto it = parameters_.find(name); + if (it != parameters_.end()) { + auto existing_param = it->second; + // Assert dtype matches + if (existing_param->dtype() != param->dtype()) { + throw std::runtime_error( + "dtype mismatch for parameter '" + name + "': " + "expected " + + std::to_string(static_cast(existing_param->dtype())) + ", got " + std::to_string(static_cast(param->dtype()))); + } + existing_param.load(param); + } else { + throw std::runtime_error("Parameter '" + name + "' not found in module."); } - existing_param->copy_from(param); } void Module::load_parameter_from_blob(const std::string &name, const void *data) { diff --git a/src/infinicore/nn/parameter.cc b/src/infinicore/nn/parameter.cc index 25b141c16..8db0473ef 100644 --- a/src/infinicore/nn/parameter.cc +++ b/src/infinicore/nn/parameter.cc @@ -3,29 +3,64 @@ #include "infinicore/context/context.hpp" #include +#include namespace infinicore::nn { Parameter::Parameter() : Tensor(Tensor::empty({}, DataType::F32, 
Device(Device::Type::CPU, 0), false)) { } +inline Shape get_partition_shape_(const Shape &shape, Size tp_dim, Size tp_size) { + if (tp_size <= 1) { + return shape; + } + Shape part_shape = shape; + if (tp_dim < shape.size()) { // an out-of-range tp_dim leaves the tensor unpartitioned + if (shape[tp_dim] % tp_size != 0) { + throw std::runtime_error("Tensor dimension " + std::to_string(tp_dim) + " with size " + std::to_string(shape[tp_dim]) + " is not divisible by tensor parallel size " + std::to_string(tp_size) + "."); + } + part_shape[tp_dim] = shape[tp_dim] / tp_size; + } + return part_shape; +} + Parameter::Parameter( const Shape &shape, const DataType &dtype, - const Device &device) - : Tensor(Tensor::empty(shape, dtype, device, false)) { + const Device &device, + Size tp_dim, + Size tp_rank, + Size tp_size) + : Tensor(Tensor::empty(get_partition_shape_(shape, tp_dim, tp_size), dtype, device, false)), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) { + if (tp_rank_ >= tp_size_) { + throw std::runtime_error("Tensor parallel rank " + std::to_string(tp_rank_) + " must be less than tensor parallel size " + std::to_string(tp_size_) + "."); + } } void Parameter::load_blob(const void *data) { - auto buffer = Tensor::empty(impl_->shape(), impl_->dtype(), Device(Device::Type::CPU, 0), true); + Shape expected_shape = Shape(impl_->shape()); + if (tp_size_ > 1 && tp_dim_ < expected_shape.size()) { expected_shape[tp_dim_] *= tp_size_; } // blob holds the full tensor; skip when unpartitioned (also avoids indexing a rank-0 shape) + auto buffer = Tensor::empty(expected_shape, impl_->dtype(), Device(Device::Type::CPU, 0), true); std::memcpy(buffer->data(), data, buffer->nbytes()); + this->load(buffer); +} + +void Parameter::load(const Tensor &tensor) { + Shape expected_shape = Shape(impl_->shape()); + if (tp_size_ > 1 && tp_dim_ < expected_shape.size()) { expected_shape[tp_dim_] *= tp_size_; } // same partitioning rule as get_partition_shape_ + + if (expected_shape != tensor->shape()) { + throw std::runtime_error("Shape mismatch when loading tensor into parameter."); + } + if (impl_->dtype() != tensor->dtype()) { + throw std::runtime_error("Dtype mismatch when loading tensor into parameter."); + } + if (tp_size_ > 1 && tp_dim_ < expected_shape.size()) { // copy only this rank's slice + impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * 
impl_->size(tp_dim_), impl_->size(tp_dim_)}})); - // If parameter is on CPU, use direct memcpy; otherwise use H2D - if (impl_->device().getType() == Device::Type::CPU) { - infinicore::context::memcpyH2H(impl_->data(), buffer->data(), buffer->nbytes()); } else { - infinicore::context::memcpyH2D(impl_->data(), buffer->data(), buffer->nbytes()); - infinicore::context::syncStream(); + impl_->copy_from(tensor); } + infinicore::context::syncStream(); } } // namespace infinicore::nn diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index 9500548f7..f1add0c97 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -9,7 +9,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); - auto vocab_size = weight_shape[0]; + // auto vocab_size = weight_shape[0]; auto embedding_dim = weight_shape[1]; // Assign memory to out variables @@ -23,11 +23,10 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i void embedding_(Tensor out, Tensor input, Tensor weight) { assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); - assert(infinicore::Device::Type::CPU == input->device()); + assert(infinicore::Device::Type::CPU == input->device().getType()); auto input_shape = input->shape(); auto weight_shape = weight->shape(); - auto vocab_size = weight_shape[0]; auto embedding_dim = weight_shape[1]; // Calculate the number of token @@ -47,7 +46,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { const int64_t *input_arr = reinterpret_cast(input->data()); for (Size i = 0; i < counts; ++i) { int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < vocab_size)); + assert((idx >= 0) && (idx < weight_shape[0])); std::memcpy(out_ptr + i * bytes, weight_ptr + idx * bytes, bytes); @@ -57,7 +56,7 @@ void 
embedding_(Tensor out, Tensor input, Tensor weight) { for (Size i = 0; i < counts; ++i) { int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < vocab_size)); + assert((idx >= 0) && (idx < weight_shape[0])); std::memcpy(out_ptr + i * bytes, weight_ptr + idx * bytes, bytes); @@ -69,7 +68,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { const int64_t *input_arr = reinterpret_cast(input->data()); for (Size i = 0; i < counts; ++i) { int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < vocab_size)); + assert((idx >= 0) && (idx < weight_shape[0])); context::memcpyD2D(out_ptr + i * bytes, weight_ptr + idx * bytes, bytes); @@ -78,7 +77,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { const int32_t *input_arr = reinterpret_cast(input->data()); for (Size i = 0; i < counts; ++i) { int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < vocab_size)); + assert((idx >= 0) && (idx < weight_shape[0])); context::memcpyD2D(out_ptr + i * bytes, weight_ptr + idx * bytes, bytes); diff --git a/src/infinicore/ops/rearrange/rearrange.cc b/src/infinicore/ops/rearrange/rearrange.cc index fe9cb4e99..c70a9e930 100644 --- a/src/infinicore/ops/rearrange/rearrange.cc +++ b/src/infinicore/ops/rearrange/rearrange.cc @@ -1,4 +1,5 @@ #include "infinicore/ops/rearrange.hpp" +#include "../../utils.hpp" namespace infinicore::op { @@ -8,7 +9,9 @@ common::OpDispatcher &Rearrange::dispatcher() { }; void Rearrange::execute(Tensor y, Tensor x) { - dispatcher().lookup(context::getDevice().getType())(y, x); + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, x); } Tensor rearrange(Tensor x) { diff --git a/src/infinicore/ops/rearrange/rearrange_infiniop.cc b/src/infinicore/ops/rearrange/rearrange_infiniop.cc index a7d0717e4..c7317a100 100644 --- a/src/infinicore/ops/rearrange/rearrange_infiniop.cc +++ b/src/infinicore/ops/rearrange/rearrange_infiniop.cc @@ -18,8 +18,8 @@ 
thread_local common::OpCache caches( void calculate(Tensor y, Tensor x) { size_t seed = hash_combine(y, x); - auto device_type = context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); + auto device_type = y->device().getType(); + auto device_index = y->device().getIndex(); auto &cache = caches.getCache(device_type, device_index); diff --git a/src/infinicore/pybind11/context.hpp b/src/infinicore/pybind11/context.hpp index 657e30877..47daa514a 100644 --- a/src/infinicore/pybind11/context.hpp +++ b/src/infinicore/pybind11/context.hpp @@ -16,7 +16,8 @@ inline void bind(py::module &m) { py::arg("device_type")); m.def("set_device", &setDevice, "Set the current active device", - py::arg("device")); + py::arg("device"), + py::arg("force_cpu") = false); // Stream and handle management m.def("get_stream", &getStream, "Get the current stream"); diff --git a/src/infinicore/utils.hpp b/src/infinicore/utils.hpp index 38e262e80..0118bd3fe 100644 --- a/src/infinicore/utils.hpp +++ b/src/infinicore/utils.hpp @@ -32,3 +32,15 @@ inline struct SpdlogInitializer { throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \ } \ } while (false) + +#define INFINICORE_ASSERT_TENSORS_SAME_DEVICE(FIRST___, ...) \ + do { \ + const auto &first_device___ = (FIRST___)->device(); \ + for (const auto &tensor___ : {__VA_ARGS__}) { \ + if (first_device___ != (tensor___)->device()) { \ + throw std::runtime_error("Tensor devices mismatch " \ + + first_device___.toString() + " vs " \ + + (tensor___)->device().toString() + "."); \ + } \ + } \ + } while (false)