Simplify handling of "define box by compositing" (#2420)
When there is a "define box by compositing", in principle, we should be
indexing into the _imagined_ partitioned IterDomain, as shown below:


![](https://raw.githubusercontent.com/NVIDIA/Fuser/main/doc/dev/tma/box-by-compositing.svg)

However, we were not able to do that in the past because indexing only
supports backward propagation. With `TensorIndexer` and
`AbstractTensor`, we can easily create the _imagined_ partitioned
ValGroup and index into it.
So there is no need to use the state machine approach to manually
combine indices; instead, the index of the _imagined_ partitioned
ValGroup can be directly used as the index for the TMA PTX instructions.
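
As a minimal standalone sketch of why the state machine is no longer needed (the `CompositedDims` type and `imaginedIndex` helper below are hypothetical stand-ins, not the nvFuser API): the index into the _imagined_ partitioned dimension is just the row-major linearization of the indices of the composited dimensions, which is what indexing the merged ValGroup yields.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a group of IterDomain extents that together define one box
// by compositing; the last entry is the innermost dimension.
struct CompositedDims {
  std::vector<int64_t> extents;
};

// Index into the imagined merged (partitioned) dimension: a plain row-major
// linearization of the per-dimension indices.
int64_t imaginedIndex(const CompositedDims& dims,
                      const std::vector<int64_t>& indices) {
  assert(indices.size() == dims.extents.size());
  int64_t idx = 0;
  for (std::size_t i = 0; i < dims.extents.size(); ++i) {
    idx = idx * dims.extents[i] + indices[i];
  }
  return idx;
}

int main() {
  // Two dimensions of extents {4, 8} composited into one imagined dimension
  // of extent 32; the index (2, 5) maps to 2 * 8 + 5 = 21.
  CompositedDims dims{{4, 8}};
  assert(imaginedIndex(dims, {2, 5}) == 21);
  return 0;
}
```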
zasdfgbnm committed Jun 26, 2024
1 parent 46676f7 commit 5037d8a
Showing 5 changed files with 124 additions and 210 deletions.
194 changes: 91 additions & 103 deletions csrc/device_lower/analysis/tma.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <abstract_tensor.h>
#include <device_lower/analysis/tma.h>
#include <device_lower/lower2device.h>
#include <id_model/id_model.h>
@@ -33,7 +34,6 @@ int64_t getCpAsyncBulkTensorSwizzleSize(TensorView* smem_tv) {
return 1;
}

// Detect the pattern x = expr(..., x, ...), exclude this expr from the result.
// TODO: We should use utilities in val_graph_visitor.h so that we don't have
// to manually filter out cyclic expr groups
ExprGroups acyclicExprGroups(const ValGraph& id_graph, const ExprGroups& egs) {
@@ -79,7 +79,7 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
int64_t itemsize = dataTypeSize(gmem_tv->dtype());

const TensorIndexer& indexer = GpuLower::current()->tensorIndexer();
const ValGraph& id_graph = indexer.traversalGraph();
ValGraph& id_graph = indexer.traversalGraph();

auto gmem_alloc_dom = TensorDomain::noBroadcasts(
TensorDomain::noReductions(gmem_tv->getMaybeAllocationDomain()));
@@ -204,9 +204,9 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
// is a box dimension defined by compositing.
std::vector<ValGroup> tma_groups;
std::unordered_map<ValGroup, ValGroup> tma_g_to_box_g;
std::unordered_map<ValGroup, ValGroup> tma_g_to_tile_g;
std::unordered_map<ValGroup, ValGroup> tma_g_to_stride_g;
std::unordered_map<ValGroup, ValGroup> tma_g_to_partitioned_g;
std::unordered_map<ValGroup, std::pair<ValGroup, ValGroup>>
tma_g_to_tile_stride_g;
std::unordered_set<ValGroup> partitioned_groups;
for (const auto& tile_g : tile_groups) {
const auto& defs =
acyclicExprGroups(id_graph, id_graph.getDefinitions(tile_g));
@@ -241,11 +241,10 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
tma_groups.push_back(tma_g);
tma_g_to_box_g[tma_g] = box_g;
if (stride_g != nullptr) {
tma_g_to_tile_g[tma_g] = tile_g;
tma_g_to_stride_g[tma_g] = stride_g;
tma_g_to_tile_stride_g[tma_g] = {tile_g, stride_g};
}
if (partitioned_g != nullptr) {
tma_g_to_partitioned_g[tma_g] = partitioned_g;
partitioned_groups.insert(partitioned_g);
}
}

@@ -342,24 +341,22 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
}

// Frontier is now the TMA domain
const auto& tma_domain = frontier;

NVF_ERROR(
std::get<1>(tma_domain.back()),
std::get<1>(frontier.back()),
"The innermost dimension of the TMA domain must be contiguous");
NVF_ERROR(
tma_g_to_stride_g.count(std::get<0>(tma_domain.back())) == 0,
tma_g_to_tile_stride_g.count(std::get<0>(frontier.back())) == 0,
"When interleave is CU_TENSOR_MAP_INTERLEAVE_NONE ",
"(this is always the case for nvFuser now)",
", the first element of elementStrides must be one.");

// Validate that tma_domain is a superset of tma_groups, otherwise there is
// Validate that frontier is a superset of tma_groups, otherwise there is
// something wrong in the schedule.
{
std::unordered_set<ValGroup> seen;
std::unordered_set<ValGroup> pending_tma_groups(
tma_groups.begin(), tma_groups.end());
for (auto tuple : tma_domain) {
for (auto tuple : frontier) {
auto g = std::get<0>(tuple);
NVF_ERROR(
seen.insert(g).second,
@@ -378,105 +375,96 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {

// So far, we have inferred the TMA domain. The size of TMA domain is not
// necessarily the dimensionality of TMA because we support defining box
// by compositing. We use a state machine to infer the dimensions of TMA.
//
// by compositing. We use AbstractTensor to further merge the TMA domain to
// the imagined TMA domain.
AbstractTensor tma_domain;
std::vector<bool> contiguity;
std::vector<Val*> global_strides;
tma_domain.domain.reserve(frontier.size());
global_strides.reserve(frontier.size());
contiguity.reserve(frontier.size());
for (auto& item : frontier) {
tma_domain.domain.emplace_back(
ValGroupAndItsGraph{std::move(std::get<0>(item)), &id_graph});
contiguity.push_back(std::get<1>(item));
global_strides.push_back(std::get<2>(item));
}
// There can only be four types of ValGroups in the TMA domain:
// - P: partitioned ValGroup
// - C: coordinate ValGroup
// - SB: strided box ValGroup
// - CB: contiguous box ValGroup
//
// For the example of the Figure 6 in doc/dev/tma.md, the TMA domain is
// [I1, I2, I3, I4, I5, I6, I7, I8, I9], and the types of these IDs are
// [ C, CB, P, C, CB, CB, C, CB, CB]
//
// The algorithm works as follows: We run a 3-state machine. The state machine
// is initialized as START. After setting the initial state, we loop through
// the TMA domain from inner to outer. During the loop, for each ValGroup we
// see, we take an action and change the state of the machine. The action and
// target state depend on the current state of the machine, and the type and
// contiguity of the ValGroup we encounter. The actions and transition of
// states are shown in the following diagram:
//
// P: create new dim
// .-------------.
// | |
// '-- [START] <-'
// CB: / ^ P: ^ \ SB/C:
// create / / create \ \ create
// new / / new dim \ \ new
// dim / / \ \ dim
// v / \ v
// .--- [PENDING BOX] -----> [PENDING COORD] <--.
// | ^ ^ SB/C: | | |
// '-----------' | create | '------------'
// CB: create new | new dim if | SB/C: create new
// dim if discontiguous | discontiguous | dim if discontiguous
// otherwise merge with | or SB | or SB, otherwise merge
// prev dim | | with prev dim
// '---------------'
// CB: create new dim
//
// There are three states in the machine. The meaning of these states are:
// - START: Everything clean, nothing pending merge.
// - PENDING BOX: Is there another contiguous box ID? I can merge it into the
// current box.
// - PENDING COORD: Is there another coordinate ID? I can merge it into the
// current dimension.
enum IDType { P, C, SB, CB };
auto gtype = [&](int64_t i) {
const auto& g = tma_domain[i].as<ValGroupAndItsGraph>().group;
return partitioned_groups.count(g)
? P
: (!tma_g_to_box_g.count(g)
? C
: (tma_g_to_tile_stride_g.count(g) ? SB : CB));
};
// merge contiguous C groups and CB groups
int64_t i = 0;
while (i < (int64_t)tma_domain.size() - 1) {
if (!contiguity[i]) {
i++;
continue;
}
bool is_c = (gtype(i) == C && gtype(i + 1) == C);
bool is_cb = (gtype(i) == CB && gtype(i + 1) == CB);
if (is_c || is_cb) {
tma_domain.merge(i);
contiguity.erase(contiguity.begin() + i);
global_strides.erase(global_strides.begin() + i);
if (is_cb) {
auto g = tma_domain[i].as<ValGroupAndItsGraph>().group;
tma_g_to_box_g.emplace(g, g);
}
} else {
i++;
}
}
// merge contiguous C with SB/CB
for (auto i : c10::irange((int64_t)tma_domain.size() - 1)) {
if (!contiguity[i]) {
continue;
}
bool this_is_c = (gtype(i) == C);
bool next_is_b = (gtype(i + 1) == SB || gtype(i + 1) == CB);
if (this_is_c && next_is_b) {
auto b = tma_domain[i + 1].as<ValGroupAndItsGraph>().group;
tma_domain.merge(i);
contiguity.erase(contiguity.begin() + i);
global_strides.erase(global_strides.begin() + i);
auto g = tma_domain[i].as<ValGroupAndItsGraph>().group;
tma_g_to_box_g.emplace(g, b);
if (auto it = tma_g_to_tile_stride_g.find(b);
it != tma_g_to_tile_stride_g.end()) {
tma_g_to_tile_stride_g.emplace(g, it->second);
}
}
}

// As required by the hardware, tensors used by TMA must be in column major
std::vector<TMADim> dims;
enum { START, PENDING_BOX, PENDING_COORD } state = START;
for (auto it = tma_domain.rbegin(); it != tma_domain.rend(); it++) {
auto [g, contiguous, stride] = *it;
auto partitioned_g_it = tma_g_to_partitioned_g.find(g);
auto box_g_it = tma_g_to_box_g.find(g);
auto stride_g_it = tma_g_to_stride_g.find(g);
auto tile_g_it = tma_g_to_tile_g.find(g);
enum IDType { P, C, SB, CB };
IDType type =
(partitioned_g_it != tma_g_to_partitioned_g.end()
? P
: (box_g_it == tma_g_to_box_g.end()
? C
: (stride_g_it != tma_g_to_stride_g.end() ? SB : CB)));
bool should_create_new_dim =
!(contiguous &&
((state == PENDING_BOX && (type == CB || type == C)) ||
(state == PENDING_COORD && type == C)));

if (should_create_new_dim) {
dims.emplace_back();
dims.back().gmem_stride_bytes =
SimplifyingIrBuilder::mulExpr(stride, itemsize);
if (type == CB) {
dims.back().box = std::unique_ptr<Box>(new ContiguousBox{});
} else if (type == SB) {
dims.back().box = std::unique_ptr<Box>(new StridedBox(
box_g_it->second, tile_g_it->second, stride_g_it->second));
} else if (type == P) {
if (stride_g_it != tma_g_to_stride_g.end()) {
dims.back().box = std::unique_ptr<Box>(new StridedBox(
box_g_it->second, tile_g_it->second, stride_g_it->second));
} else {
dims.back().box =
std::unique_ptr<Box>(new ContiguousBox(box_g_it->second));
}
} else {
NVF_ERROR(type == C);
dims.back().box =
std::unique_ptr<Box>(new ImplicitSizeOneBox(gmem_tv->fusion()));
}
auto sit = global_strides.rbegin();
for (auto it = tma_domain.domain.rbegin(); it != tma_domain.domain.rend();
it++, sit++) {
auto g = it->as<ValGroupAndItsGraph>().group;
dims.emplace_back();
dims.back().partitioned = g;
if (auto it = tma_g_to_box_g.find(g); it != tma_g_to_box_g.end()) {
dims.back().box = it->second;
}
dims.back().partitioned.pushBack(g);
if (type == C) {
dims.back().coordinate.pushBack(g);
} else if (type == CB) {
ContiguousBox* box = dynamic_cast<ContiguousBox*>(dims.back().box.get());
box->box_tile.pushBack(g);
if (auto it = tma_g_to_tile_stride_g.find(g);
it != tma_g_to_tile_stride_g.end()) {
dims.back().tile = it->second.first;
dims.back().stride = it->second.second;
} else {
dims.back().tile = dims.back().box;
}

state = (type == P ? START : (type == CB ? PENDING_BOX : PENDING_COORD));
dims.back().gmem_stride_bytes =
SimplifyingIrBuilder::mulExpr(*sit, itemsize);
}
return TMAInfo(
std::move(dims),
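The new code in `getTMAInfo` classifies every ValGroup of the TMA domain as P (partitioned), C (coordinate), SB (strided box), or CB (contiguous box) and then collapses the domain in two passes: first merge adjacent contiguous C-C and CB-CB pairs, then merge a contiguous C into the box (SB or CB) directly inside it. The sketch below models those two passes with hypothetical stand-in types and a bounds-checked loop; it illustrates the rules and is not the nvFuser `AbstractTensor` code itself.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

enum class IDType { P, C, SB, CB };

struct Entry {
  IDType type;
  bool contiguous;  // merge-able with the next (inner) entry
};

void mergeTMADomain(std::vector<Entry>& domain) {
  // Pass 1: merge adjacent C-C and CB-CB pairs that are contiguous.
  std::size_t i = 0;
  while (i + 1 < domain.size()) {
    const bool both_c =
        domain[i].type == IDType::C && domain[i + 1].type == IDType::C;
    const bool both_cb =
        domain[i].type == IDType::CB && domain[i + 1].type == IDType::CB;
    if (domain[i].contiguous && (both_c || both_cb)) {
      // The merged entry inherits the contiguity of the inner entry.
      domain[i].contiguous = domain[i + 1].contiguous;
      domain.erase(domain.begin() + static_cast<std::ptrdiff_t>(i) + 1);
    } else {
      ++i;
    }
  }
  // Pass 2: merge a contiguous coordinate (C) into the box (SB or CB)
  // directly inside it; the merged entry becomes part of that box.
  for (i = 0; i + 1 < domain.size(); ++i) {
    const bool next_is_box =
        domain[i + 1].type == IDType::SB || domain[i + 1].type == IDType::CB;
    if (domain[i].contiguous && domain[i].type == IDType::C && next_is_box) {
      domain[i].type = domain[i + 1].type;
      domain[i].contiguous = domain[i + 1].contiguous;
      domain.erase(domain.begin() + static_cast<std::ptrdiff_t>(i) + 1);
    }
  }
}

int main() {
  // Types of the Figure 6 TMA domain from doc/dev/tma.md, with every entry
  // assumed contiguous purely for illustration.
  std::vector<Entry> domain{
      {IDType::C, true},  {IDType::CB, true}, {IDType::P, true},
      {IDType::C, true},  {IDType::CB, true}, {IDType::CB, true},
      {IDType::C, true},  {IDType::CB, true}, {IDType::CB, true}};
  mergeTMADomain(domain);
  std::cout << "TMA dimensionality after merging: " << domain.size() << "\n";
  // Prints 4: the nine entries collapse to [CB, P, CB, CB].
  return 0;
}
```

Under that all-contiguous assumption, the sketch reduces the nine-entry domain to four TMA dimensions; the real code additionally records the box/tile/stride mappings and global strides for each merged group.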
102 changes: 21 additions & 81 deletions csrc/device_lower/analysis/tma.h
@@ -20,97 +20,27 @@ namespace nvfuser {

// All ValGroups are in the traversal graph of tensor indexer

class Box {
public:
virtual ~Box() = default;
virtual Val* boxSize() const = 0;
virtual Val* tileSize() const = 0;
virtual Val* elementStride() const = 0;
};

class StridedBox : public Box {
const ValGroup box;
const ValGroup tile;
const ValGroup stride;

public:
StridedBox(ValGroup box, ValGroup tile, ValGroup stride)
: box(std::move(box)), tile(std::move(tile)), stride(std::move(stride)) {}

Val* boxSize() const override {
return box->front()->as<IterDomain>()->extent();
}
Val* tileSize() const override {
return tile->front()->as<IterDomain>()->extent();
}
Val* elementStride() const override {
return stride->front()->as<IterDomain>()->extent();
}
};

class ContiguousBox : public Box {
public:
// There is no striding split, so box == tile
ValGroups box_tile;

ContiguousBox() = default;
ContiguousBox(ValGroup g) : box_tile({std::move(g)}) {}

Val* boxSize() const override {
Val* size = nullptr;
for (const auto& g : box_tile) {
size = SimplifyingIrBuilder::mulExpr(
size, g->front()->as<IterDomain>()->extent());
}
return size;
}
Val* tileSize() const override {
return boxSize();
}
Val* elementStride() const override {
return box_tile.front()->front()->fusion()->oneVal();
}
};

class ImplicitSizeOneBox : public Box {
Fusion* const fusion;

public:
ImplicitSizeOneBox(Fusion* fusion) : fusion(fusion) {}

Val* boxSize() const override {
return fusion->oneVal();
}
Val* tileSize() const override {
return fusion->oneVal();
}
Val* elementStride() const override {
return fusion->oneVal();
}
};

struct TMADim {
ValGroups partitioned;
ValGroups coordinate;
std::unique_ptr<Box> box;
ValGroup partitioned;
ValGroup box;
ValGroup tile;
ValGroup stride;
Val* gmem_stride_bytes;

Val* tensorSize() const {
Val* size = nullptr;
for (const auto& g : partitioned) {
size = SimplifyingIrBuilder::mulExpr(
size, g->front()->as<IterDomain>()->extent());
}
return size;
return partitioned->front()->as<IterDomain>()->extent();
}
Val* boxSize() const {
return box->boxSize();
return box ? box->front()->as<IterDomain>()->extent()
: gmem_stride_bytes->fusion()->oneVal();
}
Val* tileSize() const {
return box->tileSize();
return tile ? tile->front()->as<IterDomain>()->extent()
: gmem_stride_bytes->fusion()->oneVal();
}
Val* elementStride() const {
return box->elementStride();
return stride ? stride->front()->as<IterDomain>()->extent()
: gmem_stride_bytes->fusion()->oneVal();
}
};

@@ -130,6 +60,16 @@ class TMAInfo {
return dims_;
}

std::vector<ValGroup> getTMADomain() const {
std::vector<ValGroup> result;
std::transform(
dims_.begin(),
dims_.end(),
std::back_inserter(result),
[](const auto& d) { return d.partitioned; });
return result;
}

Val* tileSizeBytes() const {
int64_t itemsize = dataTypeSize(gmem_tv_->dtype());
Val* size = IrBuilder::create<Val>(itemsize, DataType::Index);
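With the Box class hierarchy (StridedBox, ContiguousBox, ImplicitSizeOneBox) removed, each TMADim now carries at most one partitioned, box, tile, and stride ValGroup, and the size queries fall back to one when a group is absent. The sketch below mirrors that fallback pattern with hypothetical stand-in extents instead of ValGroups; it is illustrative only, not the nvFuser API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

struct TMADimSketch {
  int64_t partitioned_extent = 1;        // extent of the partitioned group
  std::optional<int64_t> box_extent;     // absent => implicit size-one box
  std::optional<int64_t> tile_extent;    // absent => size-one tile
  std::optional<int64_t> stride_extent;  // absent => element stride of one

  int64_t tensorSize() const { return partitioned_extent; }
  int64_t boxSize() const { return box_extent.value_or(1); }
  int64_t tileSize() const { return tile_extent.value_or(1); }
  int64_t elementStride() const { return stride_extent.value_or(1); }
};

int main() {
  // A coordinate-only dimension: no box group, so box/tile/stride are all one.
  TMADimSketch coord_dim{/*partitioned_extent=*/128};
  assert(coord_dim.boxSize() == 1 && coord_dim.elementStride() == 1);

  // A strided box of size 32 with a tile of 16 and an element stride of 2.
  TMADimSketch strided_dim{/*partitioned_extent=*/256, 32, 16, 2};
  assert(strided_dim.tileSize() == 16 && strided_dim.elementStride() == 2);
  return 0;
}
```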