Unit test for scheduling a Mma op where input B has an allocation domain. (#2208)

As a first step to supporting allocation domains in our matmul scheduler,
this unit test demonstrates how we can use reordering to schedule a Mma op
when the input TensorView (B) has an allocation domain.

The next PR on top of this will extend the scheduler to support a strided
input B in a unit test that schedules a matmul.
protonu committed May 7, 2024
1 parent d529d9a commit 5af7104
Showing 3 changed files with 159 additions and 34 deletions.
8 changes: 8 additions & 0 deletions csrc/ir/interface_nodes.h
@@ -317,6 +317,14 @@ class NVF_API TensorView : public Val {

// Reorder axes according to old2new[old_pos] = new_pos
TensorView* reorder(const std::unordered_map<int64_t, int64_t>& old2new);
TensorView* reorder(
const std::initializer_list<std::pair<const int64_t, int64_t>>& old2new);

// Reorder axes based on the vector permutation.
// In terms of the function above, this can be seen as old2new[index] =
// permutation[index].
TensorView* reorder(const std::vector<int64_t>& permutation);
TensorView* reorder(const std::initializer_list<int64_t>& permutation);

//! Swizzle the rectangular tile defined by the iterdomains corresponding
//! to the 2 given indices.
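For illustration, a minimal sketch of how the new overloads might be called (hypothetical three-axis TensorView* tv; not part of this diff):

// Map form: old2new[old_pos] = new_pos. Moves axis 0 to position 1
// and axis 1 to position 0, i.e. swaps the first two axes.
tv->reorder({{0, 1}, {1, 0}});

// Permutation form: the axis at position i moves to permutation[i].
// {1, 0, 2} expresses the same swap as the map form above.
tv->reorder({1, 0, 2});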
23 changes: 23 additions & 0 deletions csrc/tensor_view.cpp
@@ -613,6 +613,29 @@ TensorView* TensorView::reorder(
return this;
}

TensorView* TensorView::reorder(
const std::initializer_list<std::pair<const int64_t, int64_t>>& old2new) {
return reorder(std::unordered_map<int64_t, int64_t>(old2new));
}

// Convert the permutation vector into the old2new map expected by the
// overload above.
TensorView* TensorView::reorder(const std::vector<int64_t>& permutation) {
std::unordered_map<int64_t, int64_t> reorder_map;
int64_t idx = 0;
std::transform(
permutation.begin(),
permutation.end(),
std::inserter(reorder_map, reorder_map.end()),
[&idx](const int64_t v) { return std::make_pair(idx++, v); });

return reorder(reorder_map);
}

TensorView* TensorView::reorder(
const std::initializer_list<int64_t>& permutation) {
return reorder(std::vector<int64_t>(permutation));
}

TensorView* TensorView::swizzle(
SwizzleType swizzle_type,
int64_t x,
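As a standalone illustration of the conversion above, the same std::transform pattern works outside nvFuser (plain standard-library sketch; the helper name is ours):

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <unordered_map>
#include <vector>

// Mirrors TensorView::reorder's permutation handling: pair each index
// with the position its axis should move to.
std::unordered_map<int64_t, int64_t> permutationToOld2New(
    const std::vector<int64_t>& permutation) {
  std::unordered_map<int64_t, int64_t> old2new;
  int64_t idx = 0;
  std::transform(
      permutation.begin(),
      permutation.end(),
      std::inserter(old2new, old2new.end()),
      [&idx](const int64_t v) { return std::make_pair(idx++, v); });
  return old2new;
}

// e.g. permutationToOld2New({1, 0, 2}) yields {{0, 1}, {1, 0}, {2, 2}}.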
162 changes: 128 additions & 34 deletions tests/cpp/test_mma.cpp
@@ -16,6 +16,9 @@
#include <ir/all_nodes.h>
#include <ops/all_ops.h>
#include <scheduler/mma_utils.h>
#include <algorithm>
#include <iterator>
#include <unordered_map>

namespace nvfuser {

Expand Down Expand Up @@ -85,43 +88,68 @@ class MmaTest : public NVFuserFixtureParamTest<MmaTestParams> {
}
};

std::vector<at::Tensor> scheduleCompileAndRun(
    Fusion* fusion,
    TensorView* tva,
    TensorView* tvb,
    std::pair<at::Tensor, at::Tensor> inputs,
    int64_t dim_to_reduce,
    MmaMacro macro,
    bool propagate_backwards) {
  fusion->addInput(tva);
  fusion->addInput(tvb);

  // Just doing a gmem->register copy
  auto tv0 = set(tva);

  // Just doing a gmem->register copy
  auto tv1 = set(tvb);

  // Dim to reduce is 1 for [M, K, N] and 2 for [M, N, K].
  auto tv2 = fusedMultiplySum(tv0, tv1, {dim_to_reduce});
  fusion->addOutput(tv2);

  auto mma_ops = ir_utils::getOpsOfType<MmaOp>(fusion);
NVF_CHECK(
1 == mma_ops.size(),
"Invalid number of MmaOp instances in fusion definition, expected 1, got ",
mma_ops.size());
mma_ops.front()->setMacro(macro);

// In this test we don't handle input A (tv0) having
// an allocation domain.
NVF_CHECK(
!tva->hasAllocation(),
"tva cannot have an allocation domain in this test");

if (tvb->hasAllocation()) {
// Get the permutation that describes the difference
// between the rfactor domain and allocation domain.
auto b_permutation =
ir_utils::computePermutation(
tvb->getMaybeRFactorDomain(), tvb->getAllocationDomain())
.value();

// Reorder the output of the Mma op.
tv2->reorder(b_permutation);

// We have to propagate the changes we made to the output back to the
// inputs of the Mma op. Just for the purpose of demonstration, we also
// show how it's equivalent to applying the transform directly to the
// input of the Mma op.
if (propagate_backwards) {
scheduler_utils::BoundedDirectionalTransformPropagator::backward(
tv2, -1, {});
} else {
tv1->reorder(b_permutation);
}
}

auto tv2c = tv2->cacheBefore();

// [M, N, K] or [M, K, N] -> [N, M, K]
moveInnerBroadcastLeft(tv0);
tv0->applyMmaSwizzle(MmaOperand::A);

tv1->applyMmaSwizzle(MmaOperand::B);

tv0->merge(1);
Expand All @@ -133,21 +161,87 @@ TEST_P(MmaTest, SingleTile) {
tv2c->applyMmaSwizzle(MmaOperand::Accumulator);
tv2->applyMmaSwizzle(MmaOperand::Accumulator);

  FusionExecutor fe;
  fe.compileFusion(
      fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams);
  return fe.runFusion({inputs.first, inputs.second});
}
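For intuition about the reordering step above, here is a self-contained sketch of what a permutation computation in the spirit of ir_utils::computePermutation could look like — a hypothetical stand-in, not nvFuser's actual implementation. It returns perm such that element i of `from` sits at position perm[i] of `to`:

#include <cstdint>
#include <optional>
#include <unordered_map>
#include <vector>

// Hypothetical sketch: find perm such that applying reorder(perm)
// (old2new[i] = perm[i]) to `from` yields `to`. Returns nullopt if the
// two vectors are not permutations of each other.
template <typename T>
std::optional<std::vector<int64_t>> computePermutationSketch(
    const std::vector<T>& from,
    const std::vector<T>& to) {
  if (from.size() != to.size()) {
    return std::nullopt;
  }
  std::unordered_map<T, int64_t> position_in_to;
  for (int64_t i = 0; i < (int64_t)to.size(); ++i) {
    position_in_to[to[i]] = i;
  }
  std::vector<int64_t> perm(from.size());
  for (int64_t i = 0; i < (int64_t)from.size(); ++i) {
    auto it = position_in_to.find(from[i]);
    if (it == position_in_to.end()) {
      return std::nullopt;
    }
    perm[i] = it->second; // element i of `from` lands here in `to`
  }
  return perm;
}

// e.g. from = {A, B, C}, to = {A, C, B} gives perm = {0, 2, 1}, matching
// the {axis 0, axis 2, axis 1} allocation domain used in the test below.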

TEST_P(MmaTest, SingleTile) {
Fusion fusion;
FusionGuard fg(&fusion);
auto M = getM(macro);
auto N = getN(macro);
auto K = getK(macro);

auto tv0 = makeConcreteTensor({M, 1, K}, dtype);
auto tv1 = makeConcreteTensor({1, N, K}, dtype);

auto options =
at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
auto a_input = at::randn({M, 1, K}, options);
auto b_input = at::randn({1, N, K}, options);

auto cg_outputs = scheduleCompileAndRun(
&fusion,
tv0,
tv1,
{a_input, b_input},
2 /*dim to reduce [M, N, K]*/,
macro,
false /* propagate backwards*/);

auto tref = a_input.squeeze()
.to(at::kFloat)
.matmul(b_input.squeeze().t().to(at::kFloat));

EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_P(MmaTest, SingleTileWithStridedInput) {
Fusion fusion;
FusionGuard fg(&fusion);
auto M = getM(macro);
auto N = getN(macro);
auto K = getK(macro);

auto tv0 = makeConcreteTensor({M, K, 1}, dtype);
auto tv1 = makeConcreteTensor({1, K, N}, dtype);
tv1->setAllocationDomain({tv1->axis(0), tv1->axis(2), tv1->axis(1)}, true);

auto options =
at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
auto a_input = at::randn({M, K, 1}, options);
auto b_input = at::randn({1, K, N}, options);
b_input = b_input.as_strided(b_input.sizes(), {N * K, 1, K});

auto cg_outputs = scheduleCompileAndRun(
&fusion,
tv0,
tv1,
{a_input, b_input},
1 /*dim to reduce [M, K, N]*/,
macro,
false /* propagate backwards*/);

auto tref =
a_input.squeeze().to(at::kFloat).matmul(b_input.squeeze().to(at::kFloat));

EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));

// Clear the fusion and try propagating changes to the mma output.
fusion.clear();
tv0 = makeConcreteTensor({M, K, 1}, dtype);
tv1 = makeConcreteTensor({1, K, N}, dtype);
tv1->setAllocationDomain({tv1->axis(0), tv1->axis(2), tv1->axis(1)}, true);
cg_outputs = scheduleCompileAndRun(
&fusion,
tv0,
tv1,
{a_input, b_input},
1 /*dim to reduce [M, K, N]*/,
macro,
true /* propagate backwards*/);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}
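As a side check on the stride trick used above, the b_input layout can be verified in isolation with plain ATen (standalone sketch; the helper name is ours):

#include <ATen/ATen.h>

// A [1, K, N] view with strides {N*K, 1, K} stores K and N swapped in
// memory, which is exactly what tv1's allocation domain
// {axis 0, axis 2, axis 1} describes.
void checkStridedBLayout(int64_t K, int64_t N) {
  auto b = at::randn({1, K, N}, at::kFloat);
  auto b_strided = b.as_strided(b.sizes(), {N * K, 1, K});
  // Same logical shape as b, but laid out contiguously as [1, N, K]:
  TORCH_CHECK(b_strided.transpose(1, 2).is_contiguous());
}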

