From 1ed0e86d2ce6b84500c724b1580bac64469d1347 Mon Sep 17 00:00:00 2001
From: Meghan Cowan
Date: Wed, 26 Jun 2024 10:48:31 -0700
Subject: [PATCH] Tensor parallel MLP (#2360)

Manually sharded tensor-parallel multilayer perceptron (MLP) layer. The input
is a manually translated and sharded MLP layer taken from nanoGPT. See
https://github.com/NVIDIA/Fuser/issues/2199 for the initial compute trace.
---
 csrc/expr_evaluator.cpp               |   7 +-
 csrc/multidevice/executor.h           |   5 +-
 csrc/multidevice/utils.cpp            |  16 ++-
 csrc/tensor_view.cpp                  |  12 ++
 tests/cpp/test_multidevice_matmul.cpp | 190 +++++++++++++++++++++++++-
 5 files changed, 220 insertions(+), 10 deletions(-)

diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp
index edf55cb202d..6cf566804fa 100644
--- a/csrc/expr_evaluator.cpp
+++ b/csrc/expr_evaluator.cpp
@@ -155,9 +155,14 @@ void ExpressionEvaluator::bind_(
           id->toString(),
           "is sharded and must have size 1, but input tensor has size ",
           t.size(i));
+      NVF_CHECK(
+          tv->getDeviceMesh().size() > 0,
+          "TV ",
+          tv->toString(),
+          " has an empty DeviceMesh with DID parallelization")
       bind_(
           logical_domain[i]->extent(),
-          (int)tv->getDeviceMesh().vector().size(),
+          (int)tv->getDeviceMesh().size(),
           evaluate_validate);
     } else {
       bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h
index a90f357883b..bfcbc1ce7b9 100644
--- a/csrc/multidevice/executor.h
+++ b/csrc/multidevice/executor.h
@@ -30,11 +30,10 @@ namespace nvfuser {
   parallel type ParallelType::DIDx
   We make the following assumptions on the Fusion:
-  - Only the outmost (non-reduction) axis is allowed to be parallelized
+  - Only one (non-reduction) axis is allowed to be parallelized
     with ParallelType::DIDx. Moreover, this axis cannot be split/merged.
   - We only support 1D device meshes for now
-  - We only support TensorView, not Scalars
-  - We only support static shapes
+  - We only support TensorViews in communication segments.

   Summary of the different steps performed by the MultiDeviceExecutor:
   I. At instantiation:
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 901905f64f1..719d80e2c98 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -328,6 +328,9 @@ void propagateShardings(Fusion* fusion) {
   for (auto expr : fusion->exprs()) {
     auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
     auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
+    if (inputs.empty()) {
+      continue;
+    }
     TensorView* input_with_mesh = nullptr;
     for (auto tv : inputs) {
       NVF_CHECK(
@@ -522,11 +525,14 @@ void unshard(Fusion* fusion) {

 std::set<DeviceIdxType> involvedDevices(Expr* expr) {
   std::set<DeviceIdxType> ret;
-  for (const auto& tvs : {expr->inputs(), expr->outputs()}) {
-    for (auto val : tvs) {
-      NVF_ERROR(val->isA<TensorView>(), "Val is not a TensorView");
-      auto tv = val->as<TensorView>();
-      NVF_ERROR(tv->hasDeviceMesh(), "the TensorView has no device mesh");
+  for (const auto& tvs :
+       {ir_utils::filterByType<TensorView>(expr->inputs()),
+        ir_utils::filterByType<TensorView>(expr->outputs())}) {
+    for (auto* tv : tvs) {
+      NVF_ERROR(
+          tv->hasDeviceMesh(),
+          "the TensorView has no device mesh: ",
+          tv->toString());
       auto& mesh = tv->getDeviceMesh().vector();
       std::copy(mesh.begin(), mesh.end(), std::inserter(ret, ret.end()));
     }
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 596371053b1..d2b09b433da 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -1068,6 +1068,10 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {
   consumer->setDomain(replayed_consumer_pair.first);

+  if (consumer->hasDeviceMesh()) {
+    producer->setDeviceMesh(consumer->getDeviceMesh());
+  }
+
   return producer;
 }

@@ -1108,6 +1112,10 @@ TensorView* TensorView::cacheFork() {
   IrBuilder::createInContainer<LoadStoreOp>(
       container(), LoadStoreOpType::Set, new_output, this);

+  if (this->hasDeviceMesh()) {
+    new_output->setDeviceMesh(this->getDeviceMesh());
+  }
+
   // The new TV becomes an output.
   // New TV has global memory type.
   // This TV has local memory type.
@@ -1188,6 +1196,10 @@ TensorView* TensorView::cacheAfter(
   // Set domain of producer - No Change
   TensorView* producer = this;

+  if (producer->hasDeviceMesh()) {
+    consumer->setDeviceMesh(producer->getDeviceMesh());
+  }
+
   // Insert consumer - Cache_After (CA) - after this TV.
   // Before: This TV -> [Use Op] -> Next TV
   // After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV
diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp
index 204302b1923..48877affa0e 100644
--- a/tests/cpp/test_multidevice_matmul.cpp
+++ b/tests/cpp/test_multidevice_matmul.cpp
@@ -32,7 +32,8 @@ namespace nvfuser {

-class DistributedMatmulTest : public MultiDeviceTest {
+class DistributedMatmulTest : public MultiDeviceTest,
+                              public testing::WithParamInterface<bool> {
  protected:
   DistributedMatmulTest() : num_devices_(communicator_->size()) {}

@@ -404,4 +405,191 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
                        ->heuristic();
   EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
 }
+
+TEST_P(DistributedMatmulTest, MLP_Layer) {
+  bool use_aten_matmul = GetParam();
+  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  auto mesh = DeviceMesh::createForNumDevices(communicator_->size());
+
+  int64_t sb = 64; // sequence * batch
+  int64_t h = 128;
+  int64_t h4 = 4 * h;
+
+  // TODO: error with dynamic shape
+  // C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED
+  // at "csrc/dynamic_transform.cpp":276, Could not evaluate dynamic extent: i3
+  // Exception raised from DynamicTransformConcretizationInfo at
+  // csrc/dynamic_transform.cpp:276
+  TensorView* x = makeContigConcreteTensor({sb, h}, DataType::BFloat16);
+  TensorView* w0 = makeContigConcreteTensor(
+      {num_devices_, h4 / num_devices_, h}, DataType::BFloat16);
+  TensorView* b0 = makeContigConcreteTensor(
+      {num_devices_, h4 / num_devices_}, DataType::BFloat16);
+  TensorView* w1 = makeContigConcreteTensor(
+      {num_devices_, h, h4 / num_devices_}, DataType::BFloat16);
+  TensorView* b1 = makeContigConcreteTensor({h}, DataType::BFloat16);
+  fusion->addInput(x);
+  fusion->addInput(w0);
+  fusion->addInput(b0);
+  fusion->addInput(w1);
+  fusion->addInput(b1);
+
+  // Linear #1
+  TensorView* matmul1;
+  if (use_aten_matmul) {
+    // TODO: use linear op instead
+    TensorView* w0_t = transpose(w0, 2, 1);
+    matmul1 = matmul(x, w0_t);
+  } else {
+    TensorView* linear_int0 = broadcast(x, {true, false, true, false});
+    TensorView* linear_int1 = broadcast(w0, {false, true, false, false});
+    TensorView* linear_int2 = mul(linear_int0, linear_int1);
+    matmul1 = sum(linear_int2, {-1});
+    // TODO: linear_int0 has a bcast device axis that the sharding propagation
+    // pass misses.
+    linear_int0->setDeviceMesh(mesh);
+    linear_int0->axis(0)->parallelize(ParallelType::DIDx);
+  }
+  TensorView* b0_bcast = broadcast(b0, {false, true, false});
+  TensorView* linear1 = add(matmul1, b0_bcast);
+
+  TensorView* linear1_ = castOp(DataType::Float, linear1);
+  TensorView* gelu = tanh_gelu(linear1_);
+  TensorView* gelu_ = castOp(DataType::BFloat16, gelu);
+
+  // Linear #2
+  TensorView* local_matmul2;
+  if (use_aten_matmul) {
+    TensorView* w1_t = transpose(w1, 1, 2);
+    local_matmul2 = matmul(gelu_, w1_t);
+  } else {
+    // segment_set required to ensure the matmul scheduler is called
+    gelu_ = segment_set(gelu_);
+    TensorView* linear2_int0 = broadcast(gelu_, {false, false, true, false});
+    TensorView* linear2_int1 = broadcast(w1, {false, true, false, false});
+    TensorView* linear2_int2 = mul(linear2_int0, linear2_int1);
+    local_matmul2 = sum(linear2_int2, {-1});
+  }
+
+  TensorView* matmul2 = sum(local_matmul2, {0}); // Allreduce
+  TensorView* bcast_bias = broadcast(b1, {true, false});
+  TensorView* linear2 = add(matmul2, bcast_bias);
+
+  // Dropout
+  // Note: Propagation breaks at rand_like because it creates a fresh TV.
+  // Temporarily this prevents us from using dropout composite node.
+  TensorView* linear2_ = castOp(DataType::Float, linear2);
+  constexpr double kProb = 0.1;
+  constexpr double kScale = 1.0 / (1.0 - kProb);
+  Val* philox_seed = fusion->zeroVal();
+  Val* philox_offset = fusion->zeroVal();
+  TensorView* rand_vals = rand_like(linear2_, philox_seed, philox_offset);
+  TensorView* mask = lt(rand_vals, IrBuilder::create<Val>(1.0 - kProb));
+  TensorView* apply_mask = mul(linear2_, mask);
+  TensorView* dropout = mul(apply_mask, IrBuilder::create<Val>(kScale));
+
+  fusion->addOutput(linear1);
+  fusion->addOutput(gelu);
+  fusion->addOutput(linear2);
+  fusion->addOutput(dropout);
+
+  // Manually shard inputs: x, w0, b0, w1, b1
+  // outputs: linear1, gelu, linear2, dropout
+  // TVs where sharding changes: matmul2
+  // (TODO) TVs where sharding propagation breaks down:
+  //   linear_int0 = broadcasts where a device dim axis is broadcasted.
+  //   rand_vals => rand_like creates a fresh new TV.
+
+  // TVs replicated on each device.
+  auto tv_inputs = {x, b1, matmul2, linear2, rand_vals, dropout};
+  for (auto tv : tv_inputs) {
+    tv->setDeviceMesh(mesh);
+  }
+
+  // TVs sharded on the outermost dimension.
+  auto tvs = {w0, b0, w1, linear1, gelu, gelu_};
+  for (auto tv : tvs) {
+    tv->setDeviceMesh(mesh);
+    tv->axis(0)->parallelize(ParallelType::DIDx);
+  }
+
+  const auto options = at::TensorOptions()
+                           .dtype(c10::ScalarType::BFloat16)
+                           .device(at::kCUDA, communicator_->local_rank());
+  auto x_ = at::randn({sb, h}, options);
+  auto w0_ = at::randn({h4, h}, options);
+  auto b0_ = at::randn({h4}, options);
+  auto w1_ = at::randn({h, h4}, options);
+  auto b1_ = at::randn({h}, options);
+
+  std::vector<c10::IValue> inputs = {
+      x_,
+      shardTensor(
+          w0_.view({num_devices_, h4 / num_devices_, h}),
+          w0,
+          communicator_->deviceId()),
+      shardTensor(
+          b0_.view({num_devices_, h4 / num_devices_}),
+          b0,
+          communicator_->deviceId()),
+      shardTensor(
+          w1_.view({h, num_devices_, h4 / num_devices_}).transpose(1, 0),
+          w1,
+          communicator_->deviceId()),
+      b1_};
+  at::manual_seed(0);
+  auto linear1_aten =
+      at::linear(x_.to(at::kDouble), w0_.to(at::kDouble), b0_.to(at::kDouble));
+  auto gelu_aten = at::gelu(linear1_aten.to(at::kFloat), "tanh");
+  auto linear2_aten = at::linear(
+      gelu_aten.to(at::kBFloat16).to(at::kDouble),
+      w1_.to(at::kDouble),
+      b1_.to(at::kDouble));
+  auto dropout_aten = at::dropout(linear2_aten.to(at::kFloat), kProb, true);
+  std::vector<at::Tensor> expected_outputs = {
+      shardTensor(
+          at::transpose(
+              linear1_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
+          linear1,
+          communicator_->deviceId()),
+      shardTensor(
+          at::transpose(
+              gelu_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
+          gelu,
+          communicator_->deviceId()),
+      linear2_aten,
+      dropout_aten};
+
+  at::manual_seed(0);
+  MultiDeviceExecutor runtime(
+      std::move(fusion), *communicator_, executor_params_);
+  auto outputs = runtime.runWithInput(inputs);
+
+  // Bump up the tolerance - the second matmul carries
+  // the numerical error from the prior matmul
+  auto tolerance_overwrite = ValidationConstants();
+  std::array<std::array<double, 2>, 20> relaxed_sum_tol;
+  for (auto& arr : relaxed_sum_tol) {
+    arr = {128, 3.0};
+  }
+  tolerance_overwrite.sum_tolerances_float = relaxed_sum_tol;
+
+  testValidate(
+      runtime.completeFusion(),
+      outputs,
+      inputs,
+      expected_outputs,
+      __LINE__,
+      __FILE__,
+      "",
+      LaunchParams(),
+      tolerance_overwrite);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    ,
+    DistributedMatmulTest,
+    testing::Bool(),
+    testing::PrintToStringParamName());
 } // namespace nvfuser
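
Note on the sharding scheme (not part of the patch): the test relies on the standard column-parallel / row-parallel MLP decomposition, where each device owns a slice of w0 along the h4 dimension plus the matching slice of w1, and the partial results of the second matmul are summed by an allreduce (the sum over the DIDx axis above). The following self-contained sketch shows why that reproduces the unsharded result. It is illustrative only, a minimal plain-C++ program with no nvFuser or ATen calls and with made-up names; bias, GELU, and dropout are omitted since they are elementwise per shard (or applied after the allreduce) and do not change the argument.

// Illustrative only: column-/row-parallel MLP decomposition with the
// allreduce simulated as a sum over per-device partial results.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// y = x * w^T  (x: [m, k], w: [n, k], y: [m, n])
Matrix matmulT(const Matrix& x, const Matrix& w) {
  Matrix y(x.size(), std::vector<float>(w.size(), 0.0f));
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < w.size(); ++j)
      for (size_t k = 0; k < x[0].size(); ++k)
        y[i][j] += x[i][k] * w[j][k];
  return y;
}

int main() {
  const int d = 2;      // number of devices
  const int sb = 4;     // sequence * batch
  const int h = 6;      // hidden size
  const int h4 = 4 * h; // MLP inner size

  // Deterministic test data: x [sb, h], w0 [h4, h], w1 [h, h4].
  auto fill = [](int rows, int cols) {
    Matrix m(rows, std::vector<float>(cols));
    for (int i = 0; i < rows; ++i)
      for (int j = 0; j < cols; ++j)
        m[i][j] = std::sin(0.1f * (i * cols + j));
    return m;
  };
  Matrix x = fill(sb, h), w0 = fill(h4, h), w1 = fill(h, h4);

  // Reference (unsharded): linear2 = (x * w0^T) * w1^T.
  Matrix ref = matmulT(matmulT(x, w0), w1);

  // Tensor parallel: device r owns rows [r*h4/d, (r+1)*h4/d) of w0
  // (column-parallel first linear) and the matching columns of w1
  // (row-parallel second linear). Summing the per-device partial outputs
  // plays the role of the allreduce over the DIDx axis.
  Matrix out(sb, std::vector<float>(h, 0.0f));
  const int shard = h4 / d;
  for (int r = 0; r < d; ++r) {
    Matrix w0_r(w0.begin() + r * shard, w0.begin() + (r + 1) * shard);
    Matrix w1_r(h, std::vector<float>(shard));
    for (int i = 0; i < h; ++i)
      for (int j = 0; j < shard; ++j)
        w1_r[i][j] = w1[i][r * shard + j];
    Matrix partial = matmulT(matmulT(x, w0_r), w1_r); // local to device r
    for (int i = 0; i < sb; ++i)
      for (int j = 0; j < h; ++j)
        out[i][j] += partial[i][j]; // "allreduce"
  }

  for (int i = 0; i < sb; ++i)
    for (int j = 0; j < h; ++j)
      assert(std::fabs(out[i][j] - ref[i][j]) < 1e-3f);
  std::printf("sharded and unsharded MLP agree\n");
  return 0;
}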