1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -119,6 +119,7 @@ class Val;
f(Partition); \
f(Combine); \
f(Swizzle); \
f(Swizzle1D); \
f(Resize); \
f(MatmulOp); \
f(LinearOp); \
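For context, the new `f(Swizzle1D);` entry extends an X-macro style type list: every visitor or dispatcher that expands the list picks up a Swizzle1D slot without further edits. A minimal, self-contained sketch of that pattern, using made-up macro and class names (FOR_EACH_EXPR_TYPE, Visitor, DECLARE_HANDLE) rather than the actual ones in csrc/dispatch.h:

#include <iostream>

struct Split {};
struct Swizzle1D {};

// Illustrative stand-in for the dispatch list; the real macro has a
// different name and many more entries.
#define FOR_EACH_EXPR_TYPE(f) \
  f(Split);                   \
  f(Swizzle1D);

struct Visitor {
  // Expanding the list once declares one handle() overload per entry, so a
  // new IR node only needs a single new f(...) line in the list.
#define DECLARE_HANDLE(T) void handle(const T&)
  FOR_EACH_EXPR_TYPE(DECLARE_HANDLE)
#undef DECLARE_HANDLE
};

void Visitor::handle(const Split&) { std::cout << "Split\n"; }
void Visitor::handle(const Swizzle1D&) { std::cout << "Swizzle1D\n"; }

int main() {
  Visitor v;
  v.handle(Split{});
  v.handle(Swizzle1D{});
  return 0;
}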
27 changes: 17 additions & 10 deletions csrc/multidevice/allocation_utils.cpp
@@ -93,16 +93,23 @@ void shardAllocationAsLoop(
{loop_ids_to_replicate.begin(), loop_ids_to_replicate.end()});

for (auto* e : transforms) {
auto* split = dynamic_cast<Split*>(e);
NVF_ERROR(
split != nullptr,
"Expected all transform exprs to be a split between allocation and "
"loop domain during sharding propagation.");
const auto [contiguity, split_i] =
allocation_to_contiguity.erase(split->in());
auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity);
allocation_to_contiguity.insert(split_i, split->outer(), outer_contiguity);
allocation_to_contiguity.insert(split_i, split->inner(), inner_contiguity);
if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(e)) {
const auto [contiguity, swizzle_i] =
allocation_to_contiguity.erase(swizzle1d->in());
allocation_to_contiguity.insert(swizzle_i, swizzle1d->out(), contiguity);
continue;
}
if (auto* split = dynamic_cast<Split*>(e)) {
const auto [contiguity, split_i] =
allocation_to_contiguity.erase(split->in());
auto [outer_contiguity, inner_contiguity] = splitContiguity(contiguity);
allocation_to_contiguity.insert(
split_i, split->outer(), outer_contiguity);
allocation_to_contiguity.insert(
split_i, split->inner(), inner_contiguity);
continue;
}
NVF_THROW("Expected a swizzle1d or split transform. Got: ", e);
}

std::vector<IterDomain*> new_allocation_domain;
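The rewritten loop keeps the allocation-domain contiguity bookkeeping consistent for both transform kinds: a Swizzle1D is one-to-one, so its contiguity flag carries over unchanged, while a Split replaces one entry with an outer/inner pair via splitContiguity. Below is a hypothetical, self-contained model of that bookkeeping using a plain vector of name/flag pairs; the container, the helper names, and the assumed splitContiguity rule (outer contiguous by construction, inner inheriting the old flag) are illustrative and not taken from the nvFuser sources.

#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model: each entry pairs an id name with its contiguity flag.
// The real code keys on IterDomain* and uses a dedicated ordered container.
using ContigList = std::vector<std::pair<std::string, std::optional<bool>>>;

// Split: one entry becomes an (outer, inner) pair. The rule modeled here
// (outer contiguous by construction, inner inheriting the old flag) is an
// assumption about splitContiguity, not read off the diff.
void split_entry(ContigList& list, std::size_t i) {
  const auto [name, contig] = list[i];
  list.erase(list.begin() + i);
  list.emplace(list.begin() + i, name + "_inner", contig);
  list.emplace(list.begin() + i, name + "_outer", true);
}

// Swizzle1D: one id in, one id out, so the flag carries over unchanged.
void swizzle_entry(ContigList& list, std::size_t i) {
  list[i].first += "_swizzled";
}

int main() {
  ContigList alloc = {{"i0", true}, {"i1", false}};
  split_entry(alloc, 0);   // {i0_outer, i0_inner, i1}
  swizzle_entry(alloc, 2); // i1 keeps its flag, only the id changes
  return 0;
}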
97 changes: 56 additions & 41 deletions csrc/multidevice/propagation.cpp
@@ -181,51 +181,66 @@ void transformLoopDomain(
{device_or_stream_ids.begin(), device_or_stream_ids.end()});

for (Expr* transform : transforms) {
auto* split = dynamic_cast<Split*>(transform);
NVF_ERROR(
split != nullptr,
"Expected a split transform producing the device/stream id. Got: ",
transform);

IterDomain* ref_id = split->in();
IterDomain* target_id = get_target_id(ref_id);
NVF_ERROR(
transformed_loop.contains(target_id),
"Expected the target ID, ",
target_id,
", to be in the loop domain.");

// Sharding on producer logical id is equivalent to sharding on the
// outermost consumer reshaped id iff:
// 1. The reference is outer split by num_devices.
// 2. The extent of sharded id in producer / consumer is divisible by
// split_factor. NOTE: We can only check if DID(d) is on the outer of the
// split regardless of the split_factor. However, when applying the split to
// the target, the split_factor will need to be num_devices. For e.g.: A[h]
// -> reshape -> B[a, h/a] If A is inner split `h/d`, then directly
// replaying the split on `a` will produce `a/(h/d), h/d` instead of `d,
// a/d`. So we should instead outer split by num_devices.

// Find the consumer between the reference and target.
auto [consumer_id, consumer_tv] = direction == PropagateDirection::kForward
? std::make_pair(target_id, target)
: std::make_pair(ref_id, ref);

if (hasRootToLogicalTransform(consumer_id, consumer_tv)) {
validate_split(split, target_id);
if (auto* swizzle1d = dynamic_cast<Swizzle1D*>(transform)) {
IterDomain* ref_id = swizzle1d->in();
IterDomain* target_id = get_target_id(ref_id);
NVF_ERROR(
transformed_loop.contains(target_id),
"Expected the target ID, ",
target_id,
", to be in the loop domain.");
auto it = transformed_loop.erase(target_id).second;
auto replayed_id =
IterDomain::swizzle1d(target_id, swizzle1d->parallelType());
transformed_loop.insert(it, replayed_id, std::monostate());
ref2target[swizzle1d->out()] = replayed_id;
continue;
}
if (auto* split = dynamic_cast<Split*>(transform)) {
IterDomain* ref_id = split->in();
IterDomain* target_id = get_target_id(ref_id);
NVF_ERROR(
transformed_loop.contains(target_id),
"Expected the target ID, ",
target_id,
", to be in the loop domain.");

// Sharding on a producer logical id is equivalent to sharding on the
// outermost consumer reshaped id iff:
// 1. The reference is outer split by num_devices.
// 2. The extent of the sharded id in the producer / consumer is divisible by
// split_factor.
// NOTE: We can only check that DID(d) is on the outer of the split, regardless
// of the split_factor. However, when applying the split to the target, the
// split_factor must be num_devices. For example: A[h] -> reshape -> B[a, h/a].
// If A is inner split by `h/d`, then directly replaying that split on `a`
// produces `a/(h/d), h/d` instead of `d, a/d`. So we should instead outer
// split by num_devices.

// Find the consumer between the reference and target.
auto [consumer_id, consumer_tv] =
direction == PropagateDirection::kForward
? std::make_pair(target_id, target)
: std::make_pair(ref_id, ref);

if (hasRootToLogicalTransform(consumer_id, consumer_tv)) {
validate_split(split, target_id);
}

auto it = transformed_loop.erase(target_id).second;
auto [outer, inner] =
IterDomain::split(target_id, split->factor(), split->innerSplit());

transformed_loop.insert(it, outer, std::monostate());
transformed_loop.insert(it, inner, std::monostate());

// Add mapping between ref and target for the propagated DID split.
// This is used to propagate 2D sharding and parallelization.
ref2target[split->outer()] = outer;
ref2target[split->inner()] = inner;
continue;
}
NVF_THROW("Expected a split or swizzle1d transform. Got: ", transform);
}

// Parallelize based on the ref2target map.
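The comment in the new split branch argues that replaying the reference's inner split directly on the reshaped consumer id would leave an id of the wrong extent under DIDx, which is why the target is outer split by num_devices instead. A small standalone check of that arithmetic, using made-up sizes (h = 16, d = 4, a = 8):

int main() {
  // Producer A[h] is reshaped to B[a, h/a]; d is the number of devices.
  constexpr int h = 16, d = 4, a = 8;

  // Inner-splitting A by h/d gives [h/(h/d), h/d] = [d, h/d]; DIDx sits on an
  // id of extent d.
  static_assert(h / (h / d) == d);

  // Naively replaying that inner split (factor h/d) on B's `a` id gives
  // [a/(h/d), h/d] = [2, 4]; the outer id no longer has extent d.
  static_assert(a / (h / d) == 2);

  // Outer-splitting B's `a` id by num_devices gives [d, a/d] = [4, 2], keeping
  // an extent-d id for DIDx, consistent with the producer.
  static_assert(a / d == 2);
  return 0;
}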
12 changes: 10 additions & 2 deletions csrc/transform_iter.cpp
@@ -18,9 +18,10 @@ namespace nvfuser {

// Transform dispatch
void ReplayTransformations::dispatch(Expr* e) {
auto is_supported_expr = e->isOneOf<Split, Merge, Swizzle, Resize>();
NVF_ERROR(
is_supported_expr, "Invalid expr type found in transform traversal.");
(e->isOneOf<Split, Merge, Swizzle, Swizzle1D, Resize>()),
"Unsupported expr found in traversal: ",
e);
IterVisitor::dispatch(e);
}

@@ -185,6 +186,13 @@ void ReplayTransformations::handle(Swizzle* swizzle) {
id_map_[swizzle->outY()] = outs.second;
}

void ReplayTransformations::handle(Swizzle1D* swizzle1d) {
NVF_THROW(
"Swizzle1D replay not supported in ReplayTransformations, use ReplaySelf "
"instead: ",
swizzle1d->toString());
}

void ReplayTransformations::handle(Resize* exp) {
auto id_in = exp->in();

2 changes: 2 additions & 0 deletions csrc/transform_iter.h
@@ -118,6 +118,8 @@ class ReplayTransformations : public IterVisitor {
// if replaying swizzle is enabled.
void handle(Swizzle* m) override;

void handle(Swizzle1D* swizzle1d) override;

void handle(Resize* resize) override;

size_t newCounter() {
25 changes: 25 additions & 0 deletions csrc/transform_replay.cpp
@@ -122,6 +122,31 @@ class ReplaySelf : public ReplayTransformations {
NVF_THROW("Unexpected expr to self replay: ", swizzle->toString());
}

void handle(Swizzle1D* swizzle1d) override {
auto id_in = swizzle1d->in();
auto it = id_map_.find(id_in);
if (it == id_map_.end()) {
if (!error_on_failure_) {
return;
}
NVF_THROW("Transform traversal failed, dependencies not met.");
}
auto mapped = it->second;

NVF_ERROR(
loop_ids_.find(mapped) != loop_ids_.end(),
"Transform traversal failed, modified a node but it was not a loop "
"node.");

auto replayed_id = IterDomain::swizzle1d(mapped, swizzle1d->parallelType());

loop_ids_.erase(mapped);

loop_ids_[replayed_id] = newCounter();

id_map_[swizzle1d->out()] = replayed_id;
}

void handle(Resize* resize) override {
auto id_in = resize->in();

24 changes: 5 additions & 19 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -10,6 +10,7 @@
#include "fusion.h"
#include "host_ir/container.h"
#include "host_ir/evaluator.h"
#include "host_ir/ops.h"
#include "host_ir/pass/stream_parallel_type.h"
#include "ir/all_nodes.h"
#include "multidevice/symmetric_tensor.h"
@@ -526,39 +527,24 @@ TEST_F(MultiDeviceTest, SwizzleWithParallelType) {
tv->outer_split(1, d);
tv->axis(1)->parallelize(ParallelType::DIDx);
tv->setAllocationDomain(tv->getLoopDomain(), true);
}
auto* allocate_out = IrBuilder::create<kir::Allocate>(
out_tv, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);

for (auto* tv : {in_tv, out_tv}) {
tv->outer_split(0, d);
tv->swizzle1d(0, ParallelType::DIDx);
tv->axis(0)->parallelize(ParallelType::Stream);
}

auto* allocate_out = IrBuilder::create<kir::Allocate>(
out_tv, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);
auto* stream_index = IrBuilder::create<Val>(DataType::Index);
auto* for_loop = IrBuilder::create<ForLoop>(
stream_index,
/*start=*/hic->zeroVal(DataType::Index),
/*stop=*/IrBuilder::create<Val>(d - 1, DataType::Index));

TensorView* in_shard =
ops::newValLike(in_tv, *in_tv->getDataType())->as<TensorView>();
hir::shardByStream(in_tv, stream_index, out_tv->definition());
TensorView* out_shard =
ops::newValLike(out_tv, *out_tv->getDataType())->as<TensorView>();

for (auto* tv : {in_shard, out_shard}) {
tv->setDeviceMesh(mesh);
tv->outer_split(1, d);
tv->axis(1)->parallelize(ParallelType::DIDx);
tv->outer_split(0, d);
tv->swizzle1d(0, ParallelType::DIDx);
tv->axis(0)->parallelize(ParallelType::Stream);
tv->setAllocationDomain(tv->getLoopDomain(), true);
}
hir::shardByStream(out_tv, stream_index, out_tv->definition());

IrBuilder::create<ShardByStream>(in_shard, in_tv, stream_index);
IrBuilder::create<ShardByStream>(out_shard, out_tv, stream_index);
auto* copy = IrBuilder::create<LoadStoreOp>(
LoadStoreOpType::Set, out_shard, in_shard);

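The updated test schedules in_tv and out_tv with outer_split(0, d), swizzle1d(0, ParallelType::DIDx), and Stream parallelization of axis 0, and obtains the per-stream shards through hir::shardByStream instead of building them by hand. A plausible reading is that the swizzle staggers which shard each device visits first as the stream loop advances; the cyclic-rotation formula in the standalone sketch below is only an assumption for illustration, not taken from the Swizzle1D implementation.

#include <cstdio>

int main() {
  constexpr int d = 4; // number of devices == extent of the swizzled axis
  for (int did = 0; did < d; ++did) {
    std::printf("device %d visits shards:", did);
    for (int s = 0; s < d; ++s) {
      // Assumed swizzle: rotate the loop index by the device id.
      std::printf(" %d", (s + did) % d);
    }
    std::printf("\n");
  }
  return 0;
}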