Skip to content

Commit

Permalink
[Codegen] Add shuffle for webgpu and metal
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Oct 27, 2023
1 parent de56d8c commit e43090b
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 37 deletions.
47 changes: 45 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ void CodeGenC::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base
stream << ref << " = " << value << ";\n";
}

// Write the name of the constructor used to build a vector value of type `t`.
// The base implementation simply forwards to PrintType, so the constructor
// is spelled the same way as the type.
void CodeGenC::PrintVecConstructor(DataType t, std::ostream& os) {  // NOLINT(*)
  this->PrintType(t, os);
}

std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType target) {
if (from == target) return value;
std::ostringstream os;
Expand Down Expand Up @@ -869,8 +873,47 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
os << ")";
}

void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
LOG(FATAL) << "Shuffle: not supported ";
void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT(*)
// Shuffle support
// vec = concat(vectors)
// result = (vec[indices[0]], vec[indices[1]], ...)
//
// print shuffle as:
// target_dtype(e0, e1, e2, .. en)

// construct the concat
std::vector<std::string> concat_vec;
// NOTE: important to print expr first
// in case each expr have their own nested expressions
// print each elements
for (const PrimExpr& vec : op->vectors) {
std::string vec_value = this->PrintExpr(vec);
if (vec.dtype().lanes() == 1) {
concat_vec.push_back(vec_value);
} else {
// print out each element
for (int i = 0; i < vec.dtype().lanes(); ++i) {
// access i-th element of each vector
std::ostringstream vec_elem_strm;
vec_elem_strm << vec_value << "[" << i << "]";
concat_vec.push_back(vec_elem_strm.str());
}
}
}
if (op->indices.size() == 1) {
// This is an extract element
os << concat_vec[Downcast<IntImm>(op->indices[0])->value];
} else {
// Print the shuffle as vector constructor
// vec(e0, e1, e2, .. en)
PrintVecConstructor(op->dtype, os);
os << '(';
for (size_t i = 0; i < op->indices.size(); ++i) {
if (i != 0) os << ", ";
os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
}
os << ')';
}
}

void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*)
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
// print store of single element.
virtual void PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value);
// Print the vector constructor name for type `t` to `os`.
// Callers follow it with a parenthesized, comma-separated element list.
virtual void PrintVecConstructor(DataType t, std::ostream& os);
// Get a cast type from to
virtual std::string CastFromTo(std::string value, DataType from, DataType target);
// Get load of single element with expression
Expand Down
47 changes: 13 additions & 34 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

// Write the CUDA vector constructor name for type `t`: the literal prefix
// "make_" immediately followed by whatever PrintType emits for `t`.
void CodeGenCUDA::PrintVecConstructor(DataType t, std::ostream& os) {
  os << "make_";
  this->PrintType(t, os);
}

void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) { // NOLINT(*)
// Declare the result.
Expand Down Expand Up @@ -1156,8 +1161,7 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {

void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
Expand All @@ -1184,8 +1188,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO

if (op->dtype.is_float16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1197,8 +1200,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO

if (op->dtype.is_bfloat16()) {
std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 2; ++i) {
if (i != 0) os << ", ";
Expand Down Expand Up @@ -1230,8 +1232,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
os << "(int)" << v;
}
} else if (op->lanes == 16 || op->lanes == 32) {
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes / 8; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1253,8 +1254,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
}

std::string v = PrintExpr(op->value);
os << "make_";
PrintType(op->dtype, os);
PrintVecConstructor(op->dtype, os);
os << '(';
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
Expand All @@ -1263,24 +1263,6 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
os << ')';
}

// Print a shuffle of scalar operands as `make_<type>(a, b, ...)`, where each
// argument is the printed operand selected by the corresponding index.
void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
  // Print every operand up front; operands with more than one lane are rejected.
  const int num_vectors = op->vectors.size();
  std::vector<std::string> to_shuffle(op->vectors.size());
  for (int i = 0; i < num_vectors; ++i) {
    ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
    to_shuffle[i] = PrintExpr(op->vectors[i]);
  }

  os << "make_";
  PrintType(op->dtype, os);
  os << '(';

  // Emit the selected operands in index order, comma separated.
  const int num_indices = op->indices.size();
  const char* separator = "";
  for (int i = 0; i < num_indices; ++i) {
    // Each index must be a constant that addresses one of the operands.
    const int64_t* val = as_const_int(op->indices[i]);
    ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
    os << separator << to_shuffle[*val];
    separator = ", ";
  }

  os << ')';
}

void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
// Non-vector cases.
if (!op->dtype.is_vector()) {
Expand Down Expand Up @@ -1459,8 +1441,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val

if (t.is_float16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
Expand All @@ -1478,8 +1459,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val

if (t.is_bfloat16()) {
if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << '(';
}
if (i % 2 == 0) {
Expand All @@ -1496,8 +1476,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
}

if (i == 0) {
os << "make_";
PrintType(t, os);
PrintVecConstructor(t, os);
os << "(";
}
os << value;
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class CodeGenCUDA final : public CodeGenC {
void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
std::ostream& os) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintVecConstructor(DataType t, std::ostream& os) final;
void PrintVecElemLoad(const std::string& vec, DataType t, int i,
std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
Expand All @@ -62,7 +63,6 @@ class CodeGenCUDA final : public CodeGenC {
std::string CastFromTo(std::string value, DataType from, DataType target) final;
// overload visitor
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
Expand Down
4 changes: 4 additions & 0 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
*/
#include "codegen_metal.h"

#include <tvm/tir/transform.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../../runtime/metal/metal_module.h"
Expand Down Expand Up @@ -327,6 +330,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO
runtime::Module BuildMetal(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));

std::ostringstream source_maker;
std::unordered_map<std::string, std::string> smap;
Expand Down

0 comments on commit e43090b

Please sign in to comment.