Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen] Add shuffle for cuda and metal #15998

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 45 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base
stream << ref << " = " << value << ";\n";
}

// Writes to `os` the token that introduces a vector-constructor expression
// for type `t`. This base implementation simply emits the type itself (via
// PrintType); the caller is expected to follow it with the parenthesised
// element list.
void CodeGenC::PrintVecConstructor(DataType t, std::ostream& os) {  // NOLINT(*)
  this->PrintType(t, os);
}

std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
Expand Down Expand Up @@ -869,8 +873,47 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
os << ")";
}

void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
LOG(FATAL) << "Shuffle: not supported ";
void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT(*)
// Shuffle support
// vec = concat(vectors)
// result = (vec[indices[0]], vec[indices[1]], ...)
//
// print shuffle as:
// target_dtype(e0, e1, e2, .. en)

// construct the concat
std::vector<std::string> concat_vec;
// NOTE: important to print expr first
// in case each expr have their own nested expressions
// print each elements
for (const PrimExpr& vec : op->vectors) {
std::string vec_value = this->PrintExpr(vec);
if (vec.dtype().lanes() == 1) {
concat_vec.push_back(vec_value);
} else {
// print out each element
for (int i = 0; i < vec.dtype().lanes(); ++i) {
// access i-th element of each vector
std::ostringstream vec_elem_strm;
vec_elem_strm << vec_value << "[" << i << "]";
concat_vec.push_back(vec_elem_strm.str());
}
}
}
if (op->indices.size() == 1) {
// This is an extract element
os << concat_vec[Downcast<IntImm>(op->indices[0])->value];
} else {
// Print the shuffle as vector constructor
// vec(e0, e1, e2, .. en)
PrintVecConstructor(op->dtype, os);
os << '(';
for (size_t i = 0; i < op->indices.size(); ++i) {
if (i != 0) os << ", ";
os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
}
os << ')';
}
}

void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
// print store of single element.
virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value);
// print vector constructor
virtual void PrintVecConstructor(DataType t, std::ostream& os);
// Get a cast type from to
virtual std::string CastFromTo(std::string value, DataType from, DataType target);
// Get load of single element with expression
Expand Down
49 changes: 14 additions & 35 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

// CUDA flavour of the vector-constructor prefix: writes "make_" immediately
// followed by the printed type name for `t`. The caller supplies the
// parenthesised argument list afterwards.
void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {
  constexpr const char* kMakePrefix = "make_";
  os << kMakePrefix;
  this->PrintType(t, os);
}

void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Declare the result.
Expand Down Expand Up @@ -1156,15 +1161,14 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {

// Prints a Ramp expression as a vector-constructor call with one argument
// per lane; lane i is written as "(base)+(stride*i)" and the arguments are
// comma separated.
// Ramps with more than 4 lanes are rejected by the CHECK_LE below.
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
// Constructor prefix for the result type (PrintVecConstructor).
PrintVecConstructor(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; i++) {
// base and stride are re-printed for every lane.
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
// Separator after every lane except the last.
if (i != op->lanes - 1) os << ", ";
}
os << "))";
os << ")";
}

void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
Expand All @@ -1184,8 +1188,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO

if (op->dtype.is_float16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1197,8 +1200,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO

if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
Expand Down Expand Up @@ -1230,8 +1232,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
os << "(int)" << v;
}
} else if (op->lanes == 16 || op->lanes == 32) {
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 8; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1253,8 +1254,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1263,24 +1263,6 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
os << ')';
}

// Prints a Shuffle whose operands are all scalars as
//   make_<type>(operand[idx0], operand[idx1], ...)
// where <type> is whatever PrintType emits for op->dtype and each idx comes
// from op->indices. Any operand with more than one lane fails the check.
void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
// Printed form of each operand, in operand order.
std::vector<std::string> to_shuffle(op->vectors.size());
for (int i = 0, e = op->vectors.size(); i < e; ++i) {
ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
to_shuffle[i] = PrintExpr(op->vectors[i]);
}
os << "make_";
PrintType(op->dtype, os);
os << '(';
for (int i = 0, e = op->indices.size(); i < e; ++i) {
// val is checked for null before use -- presumably as_const_int yields
// null for an index that is not a constant integer (confirm against its
// declaration).
const int64_t* val = as_const_int(op->indices[i]);
// The index must select one of the printed operands.
ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
if (i != 0) os << ", ";
os << to_shuffle[*val];
}
os << ')';
}

void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
// Non-vector cases.
if (!op->dtype.is_vector()) {
Expand Down Expand Up @@ -1459,8 +1441,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val

if (t.is_float16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
Expand All @@ -1478,8 +1459,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val

if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
Expand All @@ -1496,8 +1476,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
}

if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << "(";
}
os << value;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class CodeGenCUDA final : public CodeGenC {
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecConstructor(DataType t, std::ostream& os) final;
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
Expand All @@ -62,7 +63,6 @@ class CodeGenCUDA final : public CodeGenC {
std::string CastFromTo(std::string value, DataType from, DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
Expand Down
4 changes: 4 additions & 0 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
*/
#include "codegen_metal.h"

#include <tvm/tir/transform.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../runtime/metal/metal_module.h"
Expand Down Expand Up @@ -327,6 +330,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO
runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));

std::ostringstream source_maker;
std::unordered_map<std::string, std::string> smap;
Expand Down