Tensor parallel MLP (#2360)
Manually sharded tensor parallel multilayer perceptron (MLP) layer.

The input is a manually translated and sharded MLP layer taken from nanoGPT. See #2199 for the initial compute trace.
cowanmeg committed Jun 26, 2024
1 parent 5037d8a commit 1ed0e86
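
For orientation, here is a minimal sketch of the tensor-parallel layout the new test builds, assuming D devices and the same hidden sizes as the test (the standalone setup and variable names are illustrative, not code from this commit): the first linear's weight is sharded column-wise across devices, the second linear's weight row-wise, and the partial second-matmul outputs are summed over the device axis, which becomes an allreduce.

// Illustrative sketch; uses the same nvFuser calls that appear in the test below.
Fusion fusion;
FusionGuard fg(&fusion);
const int64_t D = 2;                // assumed number of devices
const int64_t h = 128, h4 = 4 * h;  // hidden sizes, as in the test
auto mesh = DeviceMesh::createForNumDevices(D);

// w0: [D, 4h/D, h] is column-parallel; w1: [D, h, 4h/D] is row-parallel.
TensorView* w0 = makeContigConcreteTensor({D, h4 / D, h}, DataType::BFloat16);
TensorView* w1 = makeContigConcreteTensor({D, h, h4 / D}, DataType::BFloat16);
for (TensorView* w : {w0, w1}) {
  w->setDeviceMesh(mesh);
  w->axis(0)->parallelize(ParallelType::DIDx);  // shard the device dimension
}
// Each device computes x @ w0[d]^T, gelu, and a partial gelu @ w1[d]^T;
// summing the partials over the device axis is the allreduce.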
Showing 5 changed files with 220 additions and 10 deletions.
7 changes: 6 additions & 1 deletion csrc/expr_evaluator.cpp
@@ -155,9 +155,14 @@ void ExpressionEvaluator::bind_(
id->toString(),
"is sharded and must have size 1, but input tensor has size ",
t.size(i));
NVF_CHECK(
tv->getDeviceMesh().size() > 0,
"TV ",
tv->toString(),
" has an empty DeviceMesh with DID parallelization")
bind_(
logical_domain[i]->extent(),
- (int)tv->getDeviceMesh().vector().size(),
+ (int)tv->getDeviceMesh().size(),
evaluate_validate);
} else {
bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
5 changes: 2 additions & 3 deletions csrc/multidevice/executor.h
@@ -30,11 +30,10 @@ namespace nvfuser {
parallel type ParallelType::DIDx
We make the following assumptions on the Fusion:
-  - Only the outmost (non-reduction) axis is allowed to be parallelized
+  - Only one (non-reduction) axis is allowed to be parallelized
with ParallelType::DIDx. Moreover, this axis cannot be split/merged.
- We only support 1D device meshes for now
-  - We only support TensorView, not Scalars
-  - We only support static shapes
+  - We only support TensorViews in communication segments.
Summary of the different steps performed by the MultiDeviceExecutor:
I. At instantiation:
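
To make these assumptions concrete, here is a minimal sketch (illustrative, not part of this change) of a TensorView sharded the way the executor expects, assuming D devices:

const int64_t D = 2;  // assumed number of devices
auto mesh = DeviceMesh::createForNumDevices(D);  // 1D device mesh
TensorView* tv = makeContigConcreteTensor({D, 64, 128});  // static shape
tv->setDeviceMesh(mesh);
tv->axis(0)->parallelize(ParallelType::DIDx);  // exactly one DIDx axis
// The DIDx axis must be left alone: splitting or merging it is not supported.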
16 changes: 11 additions & 5 deletions csrc/multidevice/utils.cpp
@@ -328,6 +328,9 @@ void propagateShardings(Fusion* fusion) {
for (auto expr : fusion->exprs()) {
auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
if (inputs.empty()) {
continue;
}
TensorView* input_with_mesh = nullptr;
for (auto tv : inputs) {
NVF_CHECK(
@@ -522,11 +525,14 @@ void unshard(Fusion* fusion) {

std::set<DeviceIdxType> involvedDevices(Expr* expr) {
std::set<DeviceIdxType> ret;
- for (const auto& tvs : {expr->inputs(), expr->outputs()}) {
-   for (auto val : tvs) {
-     NVF_ERROR(val->isA<TensorView>(), "Val is not a TensorView");
-     auto tv = val->as<TensorView>();
-     NVF_ERROR(tv->hasDeviceMesh(), "the TensorView has no device mesh");
+ for (const auto& tvs :
+      {ir_utils::filterByType<TensorView>(expr->inputs()),
+       ir_utils::filterByType<TensorView>(expr->outputs())}) {
+   for (auto* tv : tvs) {
+     NVF_ERROR(
+         tv->hasDeviceMesh(),
+         "the TensorView has no device mesh: ",
+         tv->toString());
auto& mesh = tv->getDeviceMesh().vector();
std::copy(mesh.begin(), mesh.end(), std::inserter(ret, ret.end()));
}
12 changes: 12 additions & 0 deletions csrc/tensor_view.cpp
@@ -1068,6 +1068,10 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {

consumer->setDomain(replayed_consumer_pair.first);

if (consumer->hasDeviceMesh()) {
producer->setDeviceMesh(consumer->getDeviceMesh());
}

return producer;
}

@@ -1108,6 +1112,10 @@ TensorView* TensorView::cacheFork() {
IrBuilder::createInContainer<LoadStoreOp>(
container(), LoadStoreOpType::Set, new_output, this);

if (this->hasDeviceMesh()) {
new_output->setDeviceMesh(this->getDeviceMesh());
}

// The new TV becomes an output.
// New TV has global memory type.
// This TV has local memory type.
@@ -1188,6 +1196,10 @@ TensorView* TensorView::cacheAfter(
// Set domain of producer - No Change
TensorView* producer = this;

if (producer->hasDeviceMesh()) {
consumer->setDeviceMesh(producer->getDeviceMesh());
}

// Insert consumer - Cache_After (CA) - after this TV.
// Before: This TV -> [Use Op] -> Next TV
// After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV
190 changes: 189 additions & 1 deletion tests/cpp/test_multidevice_matmul.cpp
@@ -32,7 +32,8 @@

namespace nvfuser {

- class DistributedMatmulTest : public MultiDeviceTest {
+ class DistributedMatmulTest : public MultiDeviceTest,
+                               public testing::WithParamInterface<bool> {
protected:
DistributedMatmulTest() : num_devices_(communicator_->size()) {}

@@ -404,4 +405,191 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_P(DistributedMatmulTest, MLP_Layer) {
bool use_aten_matmul = GetParam();
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto mesh = DeviceMesh::createForNumDevices(communicator_->size());

int64_t sb = 64; // sequence * batch
int64_t h = 128;
int64_t h4 = 4 * h;

// TODO: error with dynamic shape
// C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED
// at "csrc/dynamic_transform.cpp":276, Could not evaluate dynamic extent: i3
// Exception raised from DynamicTransformConcretizationInfo at
// csrc/dynamic_transform.cpp:276
TensorView* x = makeContigConcreteTensor({sb, h}, DataType::BFloat16);
TensorView* w0 = makeContigConcreteTensor(
{num_devices_, h4 / num_devices_, h}, DataType::BFloat16);
TensorView* b0 = makeContigConcreteTensor(
{num_devices_, h4 / num_devices_}, DataType::BFloat16);
TensorView* w1 = makeContigConcreteTensor(
{num_devices_, h, h4 / num_devices_}, DataType::BFloat16);
TensorView* b1 = makeContigConcreteTensor({h}, DataType::BFloat16);
fusion->addInput(x);
fusion->addInput(w0);
fusion->addInput(b0);
fusion->addInput(w1);
fusion->addInput(b1);

// Linear #1
TensorView* matmul1;
if (use_aten_matmul) {
// TODO: use linear op instead
TensorView* w0_t = transpose(w0, 2, 1);
matmul1 = matmul(x, w0_t);
} else {
TensorView* linear_int0 = broadcast(x, {true, false, true, false});
TensorView* linear_int1 = broadcast(w0, {false, true, false, false});
TensorView* linear_int2 = mul(linear_int0, linear_int1);
matmul1 = sum(linear_int2, {-1});
// TODO: linear_int0 has a bcast device axis that the sharding propagation
// pass misses.
linear_int0->setDeviceMesh(mesh);
linear_int0->axis(0)->parallelize(ParallelType::DIDx);
}
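// In both branches matmul1 has shape [D, sb, h4/D], where D = num_devices_.
// Intermediate shapes in the decomposed branch:
//   linear_int0: [1, sb, 1, h]     (x broadcast)
//   linear_int1: [D, 1, h4/D, h]   (w0 broadcast)
//   linear_int2: [D, sb, h4/D, h]  (elementwise product)
// The DIDx (device) axis stays outermost throughout.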
TensorView* b0_bcast = broadcast(b0, {false, true, false});
TensorView* linear1 = add(matmul1, b0_bcast);

TensorView* linear1_ = castOp(DataType::Float, linear1);
TensorView* gelu = tanh_gelu(linear1_);
TensorView* gelu_ = castOp(DataType::BFloat16, gelu);

// Linear #2
TensorView* local_matmul2;
if (use_aten_matmul) {
TensorView* w1_t = transpose(w1, 1, 2);
local_matmul2 = matmul(gelu_, w1_t);
} else {
// segment_set required to ensure the matmul scheduler is called
gelu_ = segment_set(gelu_);
TensorView* linear2_int0 = broadcast(gelu_, {false, false, true, false});
TensorView* linear2_int1 = broadcast(w1, {false, true, false, false});
TensorView* linear2_int2 = mul(linear2_int0, linear2_int1);
local_matmul2 = sum(linear2_int2, {-1});
}
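// In both branches local_matmul2 has shape [D, sb, h]: each device holds a
// partial product over its h4/D slice. Intermediate shapes in the decomposed
// branch:
//   linear2_int0: [D, sb, 1, h4/D]  (gelu_ broadcast)
//   linear2_int1: [D, 1, h, h4/D]   (w1 broadcast)
//   linear2_int2: [D, sb, h, h4/D]  (elementwise product)
// Summing over axis 0 below combines the partials; that is the allreduce.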

TensorView* matmul2 = sum(local_matmul2, {0}); // Allreduce
TensorView* bcast_bias = broadcast(b1, {true, false});
TensorView* linear2 = add(matmul2, bcast_bias);

// Dropout
// Note: Propagation breaks at rand_like because it creates a fresh TV.
// Temporarily this prevents us from using dropout composite node.
TensorView* linear2_ = castOp(DataType::Float, linear2);
constexpr double kProb = 0.1;
constexpr double kScale = 1.0 / (1.0 - kProb);
Val* philox_seed = fusion->zeroVal();
Val* philox_offset = fusion->zeroVal();
TensorView* rand_vals = rand_like(linear2_, philox_seed, philox_offset);
TensorView* mask = lt(rand_vals, IrBuilder::create<Val>(1.0 - kProb));
TensorView* apply_mask = mul(linear2_, mask);
TensorView* dropout = mul(apply_mask, IrBuilder::create<Val>(kScale));
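// This reproduces dropout with p = kProb in training mode: keep elements where
// rand < 1 - p, then scale the kept values by 1 / (1 - p).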

fusion->addOutput(linear1);
fusion->addOutput(gelu);
fusion->addOutput(linear2);
fusion->addOutput(dropout);

// Manually shard inputs: x, w0, b0, w1, b1
// outputs: linear1, gelu, linear2, dropout
// TVs where sharding changes: matmul2
// (TODO) TVs where sharding propagation breaks down:
// linear_int0 = broadcasts where a device dim axis is broadcasted.
// rand_vals => rand_like creates a fresh new TV.

// TVs replicated on each device.
auto tv_inputs = {x, b1, matmul2, linear2, rand_vals, dropout};
for (auto tv : tv_inputs) {
tv->setDeviceMesh(mesh);
}

// TVs sharded on the outermost dimension.
auto tvs = {w0, b0, w1, linear1, gelu, gelu_};
for (auto tv : tvs) {
tv->setDeviceMesh(mesh);
tv->axis(0)->parallelize(ParallelType::DIDx);
}

const auto options = at::TensorOptions()
.dtype(c10::ScalarType::BFloat16)
.device(at::kCUDA, communicator_->local_rank());
auto x_ = at::randn({sb, h}, options);
auto w0_ = at::randn({h4, h}, options);
auto b0_ = at::randn({h4}, options);
auto w1_ = at::randn({h, h4}, options);
auto b1_ = at::randn({h}, options);
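// The reference tensors are unsharded (e.g. w0_ is [h4, h] and w1_ is [h, h4]);
// below they are reshaped so the device dimension is outermost, matching the
// TensorViews' logical domains, and shardTensor keeps only this rank's slice.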

std::vector<c10::IValue> inputs = {
x_,
shardTensor(
w0_.view({num_devices_, h4 / num_devices_, h}),
w0,
communicator_->deviceId()),
shardTensor(
b0_.view({num_devices_, h4 / num_devices_}),
b0,
communicator_->deviceId()),
shardTensor(
w1_.view({h, num_devices_, h4 / num_devices_}).transpose(1, 0),
w1,
communicator_->deviceId()),
b1_};
at::manual_seed(0);
auto linear1_aten =
at::linear(x_.to(at::kDouble), w0_.to(at::kDouble), b0_.to(at::kDouble));
auto gelu_aten = at::gelu(linear1_aten.to(at::kFloat), "tanh");
auto linear2_aten = at::linear(
gelu_aten.to(at::kBFloat16).to(at::kDouble),
w1_.to(at::kDouble),
b1_.to(at::kDouble));
auto dropout_aten = at::dropout(linear2_aten.to(at::kFloat), kProb, true);
std::vector<at::Tensor> expected_outputs = {
shardTensor(
at::transpose(
linear1_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
linear1,
communicator_->deviceId()),
shardTensor(
at::transpose(
gelu_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
gelu,
communicator_->deviceId()),
linear2_aten,
dropout_aten};

at::manual_seed(0);
MultiDeviceExecutor runtime(
std::move(fusion), *communicator_, executor_params_);
auto outputs = runtime.runWithInput(inputs);

// Bump up the tolerance - the second matmul carries
// the numerical error from the prior matmul
auto tolerance_overwrite = ValidationConstants();
std::array<std::array<double, 2>, 20> relaxed_sum_tol;
for (auto& arr : relaxed_sum_tol) {
arr = {128, 3.0};
}
tolerance_overwrite.sum_tolerances_float = relaxed_sum_tol;

testValidate(
runtime.completeFusion(),
outputs,
inputs,
expected_outputs,
__LINE__,
__FILE__,
"",
LaunchParams(),
tolerance_overwrite);
}

INSTANTIATE_TEST_SUITE_P(
,
DistributedMatmulTest,
testing::Bool(),
testing::PrintToStringParamName());
} // namespace nvfuser
