Merge branch 'main' into make_communications_IRs
wujingyue committed May 10, 2024
2 parents 9e51818 + d214286 commit 0f7fa53
Showing 31 changed files with 1,364 additions and 927 deletions.
17 changes: 15 additions & 2 deletions README.md
@@ -38,10 +38,23 @@ PyPI: [https://pypi.org/project/nvfuser/](https://pypi.org/search/?q=nvfuser)
Docs: https://github.com/NVIDIA/Fuser/wiki

Supported compilers:
- gcc 11.4+
- clang14+

**GCC:**

We support all "supported releases" of gcc as specified in [the official site](https://gcc.gnu.org/).
As of 5/3/2024, they are:

- gcc 11.4
- gcc 12.3
- gcc 13.2
- gcc 14.1

**Clang:**

- clang 14+

Supported C++ standard:

- C++17
- C++20

2 changes: 1 addition & 1 deletion benchmarks/python/test_matmul.py
@@ -28,7 +28,7 @@ def load_matmul_problems():
"config", load_matmul_problems(), ids=lambda val: "_".join(str(v) for v in val)
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_reduction_nvf_benchmark(
def test_matmul_nvf_benchmark(
benchmark,
config: tuple,
dtype: torch.dtype,
42 changes: 26 additions & 16 deletions csrc/device_lower/pass/index.cpp
@@ -1575,6 +1575,7 @@ static MmaInputSmemSwizzle getSwizzleMode(TensorView* tv) {
// Reference for smem strides:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#strides
void IndexLowering::handle(const MmaOp* mma) {
constexpr int64_t core_matrix_outer_size = 8;
Val* a = nullptr;
Val* b = nullptr;
if (mma->inA()->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
@@ -1583,13 +1584,19 @@ void IndexLowering::handle(const MmaOp* mma) {
auto tv = mma->inA()->as<TensorView>();
auto base_addr = IrBuilder::baseAddressExpr(tv);
auto swizzle = getSwizzleMode(tv);
int64_t stride_bytes =
8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes
int64_t leading_bytes = /*8x8 items each core matrix*/ 64L *
/*number of core matrices*/ (getM(mma->macro()) / 8L) *
/*bytes per item*/ 2L;
if (swizzle != MmaInputSmemSwizzle::None) {
// TODO: why???!!!
int64_t leading_bytes = core_matrix_outer_size *
getBytesFromSwizzle(swizzle); // swizzle period in bytes
int64_t inner_size =
(mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN)
? getK(mma->macro())
: getM(mma->macro());
int64_t stride_bytes = core_matrix_outer_size *
/*number of core matrices, rounded up to handle padding */
roundUpToMultiple(inner_size * /*bytes per item*/ 2L,
getBytesFromSwizzle(swizzle));
if (swizzle == MmaInputSmemSwizzle::None &&
(mma->layout() == MmaLayout::NT || mma->layout() == MmaLayout::NN)) {
// tnspA and tnspB are ignored for NoSwizzle mode
std::swap(leading_bytes, stride_bytes);
}
auto matrix_desc = constructMatrixDescriptor(
Expand All @@ -1612,16 +1619,19 @@ void IndexLowering::handle(const MmaOp* mma) {
auto tv = mma->inB()->as<TensorView>();
auto swizzle = getSwizzleMode(tv);
auto base_addr = IrBuilder::baseAddressExpr(tv);
int64_t stride_bytes =
8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes
int64_t leading_bytes = /*8x8 items each core matrix*/ 64L *
int64_t leading_bytes = core_matrix_outer_size *
getBytesFromSwizzle(swizzle); // swizzle period in bytes
int64_t inner_size =
(mma->layout() == MmaLayout::TN || mma->layout() == MmaLayout::NN)
? getK(mma->macro())
: getN(mma->macro());
int64_t stride_bytes = core_matrix_outer_size *
/*number of core matrices, rounded up to handle padding */
roundUpToMultiple(getN(mma->macro()) / 8L,
getBytesFromSwizzle(swizzle) / 16L) *
/*bytes per item*/ 2L;
if (swizzle != MmaInputSmemSwizzle::None &&
(mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN)) {
// TODO: why???!!!
roundUpToMultiple(inner_size * /*bytes per item*/ 2L,
getBytesFromSwizzle(swizzle));
if (swizzle == MmaInputSmemSwizzle::None &&
(mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::NT)) {
// tnspA and tnspB are ignored for NoSwizzle mode
std::swap(leading_bytes, stride_bytes);
}
auto matrix_desc = constructMatrixDescriptor(
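To make the reworked descriptor arithmetic above easier to follow, here is a minimal standalone sketch that mirrors the new `leading_bytes`/`stride_bytes` computation for one assumed case: operand A with a K-major (TT/TN) layout, fp16 items, K = 16, and a 128-byte swizzle. The swizzle byte count and macro shape are illustrative assumptions, not values taken from this commit.

```cpp
#include <cstdint>
#include <iostream>

// Mirrors the arithmetic in IndexLowering::handle(const MmaOp*) above.
// Concrete numbers are illustrative assumptions only.
constexpr int64_t kCoreMatrixOuterSize = 8; // outer extent of an 8x8 core matrix
constexpr int64_t kBytesPerItem = 2;        // fp16
constexpr int64_t kSwizzleBytes = 128;      // stand-in for getBytesFromSwizzle(swizzle)

int64_t roundUpToMultiple(int64_t x, int64_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

int main() {
  // Operand A with a K-major (TT/TN) layout, so the inner size is K.
  const int64_t inner_size = 16;

  // leading_bytes = core_matrix_outer_size * swizzle bytes ("swizzle period in bytes").
  const int64_t leading_bytes = kCoreMatrixOuterSize * kSwizzleBytes; // 8 * 128 = 1024

  // stride_bytes = core_matrix_outer_size * (inner row bytes, rounded up to a
  // whole swizzle period to account for padding).
  const int64_t stride_bytes = kCoreMatrixOuterSize *
      roundUpToMultiple(inner_size * kBytesPerItem, kSwizzleBytes); // 8 * 128 = 1024

  std::cout << "leading_bytes=" << leading_bytes
            << " stride_bytes=" << stride_bytes << std::endl;
}
```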
12 changes: 6 additions & 6 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -223,18 +223,18 @@ class LowerToInlinePtx : public kir::ExprMutator {
/*scaleB=*/IrBuilder::create<Val>(1, DataType::Int32)};
auto layout = *mma->layout();
if (a_on_smem) {
// tnspA: if not K-major, then needs transpose
// tnspA
if (layout == MmaLayout::TT || layout == MmaLayout::TN) {
inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
} else {
inputs.push_back(IrBuilder::create<Val>(0, DataType::Int32));
} else {
inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
}
}
// tnspB: if not K-major, then needs transpose
// tnspB
if (layout == MmaLayout::TN || layout == MmaLayout::NN) {
inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
} else {
inputs.push_back(IrBuilder::create<Val>(0, DataType::Int32));
} else {
inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
}
registerInsertBefore(
mma,
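The net effect of the reworked tnspA/tnspB branches can be summarized per layout. The sketch below re-derives those bits outside the lowering pass; the enum and helper are stand-ins declared here so the example is self-contained, not nvFuser's own declarations.

```cpp
#include <iostream>
#include <utility>

// Stand-in for nvFuser's MmaLayout enum, redeclared here only for this sketch.
enum class MmaLayout { TT, TN, NT, NN };

// Reproduces the branch logic above: operand A needs no transpose when it is
// K-major (TT/TN); operand B needs no transpose when it is K-major (TN/NN).
std::pair<int, int> transposeBits(MmaLayout layout) {
  int tnspA = (layout == MmaLayout::TT || layout == MmaLayout::TN) ? 0 : 1;
  int tnspB = (layout == MmaLayout::TN || layout == MmaLayout::NN) ? 0 : 1;
  return {tnspA, tnspB};
}

int main() {
  const char* names[] = {"TT", "TN", "NT", "NN"};
  for (auto layout : {MmaLayout::TT, MmaLayout::TN, MmaLayout::NT, MmaLayout::NN}) {
    auto [a, b] = transposeBits(layout);
    std::cout << names[static_cast<int>(layout)] << ": tnspA=" << a
              << " tnspB=" << b << "\n";
  }
}
```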
1 change: 1 addition & 0 deletions csrc/dispatch.h
@@ -107,6 +107,7 @@ class Val;
f(Swizzle); \
f(Swizzle2D); \
f(Resize); \
f(MatmulOp); \
f(Communication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
f(Allocate); \
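For context on why a single `f(MatmulOp);` line is enough to register the new IR node: the dispatch list is an X-macro, and every consumer defines `f` and expands the list. The sketch below illustrates the general pattern with a made-up macro name and node list; it is not the actual contents of csrc/dispatch.h.

```cpp
#include <iostream>

// A generic X-macro illustration of the dispatch-list pattern. The macro name
// and node list here are hypothetical, not nvFuser's.
#define FOR_ALL_EXAMPLE_NODES(f) \
  f(Resize);                     \
  f(MatmulOp);                   \
  f(Communication);

// One consumer: forward-declare every node type in the list.
#define DECLARE_NODE(name) class name
FOR_ALL_EXAMPLE_NODES(DECLARE_NODE);
#undef DECLARE_NODE

// Another consumer: print every node name. Adding one f(MatmulOp); entry to
// the list is enough for every such consumer to pick up the new node.
#define PRINT_NODE(name) std::cout << #name << "\n"
int main() {
  FOR_ALL_EXAMPLE_NODES(PRINT_NODE);
}
#undef PRINT_NODE
```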
18 changes: 13 additions & 5 deletions csrc/ir/interface_nodes.h
@@ -315,8 +315,20 @@ class NVF_API TensorView : public Val {
return merge(axis, axis + 1);
}

// Flatten the axes from `from` to `to` into a single axis.
// Both `from` and `to` are inclusive.
TensorView* flatten(int64_t from = 0, int64_t to = -1);

// Reorder axes according to old2new[old_pos] = new_pos
TensorView* reorder(const std::unordered_map<int64_t, int64_t>& old2new);
TensorView* reorder(
const std::initializer_list<std::pair<const int64_t, int64_t>>& old2new);

// Reorder axes based on the vector permutation.
// In terms of the function above, this can be seen as old2new[index] =
// permutation[index]
TensorView* reorder(const std::vector<int64_t>& permutation);
TensorView* reorder(const std::initializer_list<int64_t>& permutation);

//! Swizzle the rectangular tile defined by the iterdomains corresponding
//! to the 2 given indices.
@@ -411,11 +423,7 @@
//! have a matching thread swizzle with the mma operand/result.
//! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h .
void applyMmaSwizzle(MmaOperand operand);
// TODO: what is transpose 2? Why do we need it?
void applyMmaSwizzle(
MmaInputSmemSwizzle swizzle,
bool transpose,
bool transpose2 = false);
void applyMmaSwizzle(MmaInputSmemSwizzle swizzle);

//! Returns if this tensor view has swizzle operator on its tensor domain.
//! This is the temporary flag for indicating that the new swizzle
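To make the old2new convention used by the new reorder overloads concrete, the sketch below applies the same rule to a plain array of axis labels. It models the documented semantics (old2new[old_pos] = new_pos) rather than calling nvFuser itself.

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Models TensorView::reorder's documented convention: old2new[old_pos] = new_pos.
// Not nvFuser code; just the permutation rule applied to labels.
std::vector<std::string> reorderAxes(
    const std::vector<std::string>& axes,
    const std::vector<int64_t>& old2new) {
  std::vector<std::string> result(axes.size());
  for (size_t old_pos = 0; old_pos < axes.size(); ++old_pos) {
    result[old2new[old_pos]] = axes[old_pos];
  }
  return result;
}

int main() {
  const std::vector<std::string> axes = {"a", "b", "c"};
  // The vector overload: old2new[index] = permutation[index].
  // {2, 0, 1} sends axis 0 -> position 2, axis 1 -> position 0, axis 2 -> position 1.
  for (const auto& name : reorderAxes(axes, {2, 0, 1})) {
    std::cout << name << " "; // prints: b c a
  }
  std::cout << "\n";
}
```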
33 changes: 33 additions & 0 deletions csrc/ir/internal_nodes.h
@@ -2255,4 +2255,37 @@ class NVF_API CatOp : public Expr {
Val* getPred(int input_idx) const;
};

//! Matmul Operator to be expression evaluated without decomposition.
class MatmulOp : public Expr {
public:
using Expr::Expr;

MatmulOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b);

NVFUSER_DECLARE_CLONE_AND_CREATE

const char* getOpString() const override {
return "MatmulOp";
}

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

Val* out() const {
return output(0);
}

Val* inA() const {
return input(0);
}

Val* inB() const {
return input(1);
}

std::vector<PolymorphicValue> evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const override;
};

} // namespace nvfuser
29 changes: 29 additions & 0 deletions csrc/ir/nodes.cpp
@@ -4472,4 +4472,33 @@ std::vector<PolymorphicValue> CatOp::evaluate(
return {at::cat(unpadded_inputs, concat_dim)};
}

MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b)
: Expr(passkey) {
addOutput(out);
addInput(in_a);
addInput(in_b);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(MatmulOp)

std::string MatmulOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << out()->toString() << "\n";
indent(ss, indent_size + 1) << " = matmul(" << inA()->toString() << ",\n";
indent(ss, indent_size + 1) << " " << inB()->toString() << ")\n";
return ss.str();
}

std::string MatmulOp::toInlineString(int indent_size) const {
NVF_CHECK(false, "Tensor op can not be printed inline");
}

std::vector<PolymorphicValue> MatmulOp::evaluate(
const ExpressionEvaluator& ee,
const std::vector<PolymorphicValue>& inputs) const {
const auto a = inputs.at(0).as<at::Tensor>();
const auto b = inputs.at(1).as<at::Tensor>();
return {at::matmul(a, b)};
}

} // namespace nvfuser
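Since MatmulOp::evaluate defers entirely to at::matmul, the op inherits ATen's matmul shape and broadcasting rules. A small standalone libtorch sketch of that behavior (shapes chosen purely for illustration, independent of nvFuser):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // The same call MatmulOp::evaluate forwards to, on illustrative shapes.
  const auto a = at::randn({8, 16, 32}); // batched LHS
  const auto b = at::randn({32, 64});    // 2-D RHS
  const auto c = at::matmul(a, b);       // broadcasts to [8, 16, 64]
  std::cout << c.sizes() << "\n";
}
```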