Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] Implement tir.dp4a with WGSL built-in function dot4I8Packed #16976

Merged
merged 8 commits into from
Jul 4, 2024
18 changes: 18 additions & 0 deletions src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ void CodeGenWebGPU::PrintType(DataType t, std::ostream& os) { // NOLINT(*)

if (lanes != 1) {
ICHECK(lanes >= 2 && lanes <= 4) << "CodeGenWebGPU: only allows vector with lanes in {2, 3, 4}";

if (t.is_int() && t.bits() == 8 && lanes == 4) {
os << "u32";
return;
}
os << "vec" << lanes << "<";
}

Expand Down Expand Up @@ -405,6 +410,19 @@ void CodeGenWebGPU::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLIN
this->EndScope(else_scope);
}
os << result;
} else if (op->op.same_as(builtin::call_pure_extern())) {
ICHECK_GE(op->args.size(), 1U);
const std::string& func_name = op->args[0].as<StringImmNode>()->value;
if (func_name == "__dp4a") {
if (op->args.size() != 3) {
LOG(FATAL) << "__dp4a can only accept 2 parameters (now: " << op->args.size() - 1 << ")";
} else {
os << "dot4I8Packed(" << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")";
Jiawei-Shao marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
LOG(FATAL) << "WGSL shader cannot make extern calls. Graph contains extern \""
<< Downcast<StringImm>(op->args[0]) << "\"";
}
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
Loading