Skip to content

Commit

Permalink
[Codegen][Metal] Disable cross-function call in Metal codegen
Browse files Browse the repository at this point in the history
This PR restores the Metal codegen to its state before apache#15835.
Since there will likely be no internal function calls in Metal,
we think it is safe to do so.

Verified that with this PR, the Metal codegen and the iPhone codegen
no longer fail and work properly.

The iPhone codegen failed because declaring the same function multiple
times leads to multiple emissions of the same struct, which the Metal
compiler does not accept.
  • Loading branch information
MasterJH5574 committed Nov 1, 2023
1 parent 7a50c36 commit b95db83
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 36 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
shell: bash -l {0}
run: >-
python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile'
python -m pytest -v -s 'tests/python/unittest/test_target_codegen_metal.py::test_func_with_trailing_pod_params'
- name: Minimal Metal Compile-and-Run
shell: bash -l {0}
run: >-
Expand Down
89 changes: 55 additions & 34 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ namespace codegen {

void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");
// analyze the data;
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
Expand All @@ -57,15 +55,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) {
<< "};\n\n";
}

void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) {
void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
// NOTE: There are no inter-function calls among Metal kernels.
// For now we keep the Metal codegen without the inter-function call
// process.
// We can switch to the flow with the inter-function call process
// once the Metal function declaration is properly printed.
// In Metal, for PrimFuncs with signature
// def func(A: Buffer, B: Buffer, x: int, y: float) -> None
// where there are trailing POD parameters, the codegen emits a struct
// struct func_params{ x: int; y: float; }
// for the function. In the flow with the inter-function call process,
// the struct is emitted every time a function is declared.
// Consequently the same struct appears multiple times,
// which the Metal compiler does not accept.

// clear previous generated state.
this->InitFuncState(func);
// skip the first underscore, so SSA variable starts from _1
name_supply_->FreshName("v_");

// add to alloc buffer type.
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";

// Function header.
os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";
this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "(";

// Buffer arguments
size_t num_buffer = 0;
Expand All @@ -77,13 +93,13 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
Var v = func->params[i];
if (!v.dtype().is_handle()) break;
os << " ";
this->stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
PrintStorageScope(it->second, this->stream);
}
PrintType(GetType(v), os);
PrintType(GetType(v), this->stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
Expand All @@ -92,14 +108,15 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri
RegisterHandleType(v.get(), prim->dtype);
}
}
os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
this->stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = func->params.size() - num_buffer;
std::string varg = name_supply_->FreshName("arg");
if (nargs != 0) {
std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t";
os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n";
this->stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer
<< ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < func->params.size(); ++i) {
Expand Down Expand Up @@ -141,16 +158,22 @@ void CodeGenMetal::PrintFunctionSignature(const String& function_name, const Pri

if (work_dim != 0) {
// use ushort by default for now
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " blockIdx [[threadgroup_position_in_grid]],\n";
os << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
os << " threadIdx [[thread_position_in_threadgroup]]\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " ";
PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
thread_work_dim_ = work_dim;

os << ")";
// the function scope.
stream << ") {\n";
int func_scope = this->BeginScope();
this->PrintStmt(func->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}

void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
Expand Down Expand Up @@ -295,6 +318,9 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
}

void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
CHECK(!op->op.as<GlobalVarNode>())
<< "CodegenMetal does not support inter-function calls, "
<< "but expression " << GetRef<Call>(op) << " calls PrimFunc " << op->op;
if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
Expand Down Expand Up @@ -337,33 +363,28 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";

Map<GlobalVar, PrimFunc> functions;
for (auto [gvar, base_func] : mod->functions) {
ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

auto prim_func = Downcast<PrimFunc>(base_func);
functions.Set(gvar, prim_func);
}
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined());
std::string func_name = global_symbol.value();

for (auto [gvar, prim_func] : functions) {
source_maker << "// Function: " << gvar->name_hint << "\n";
source_maker << "// Function: " << func_name << "\n";
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";

for (auto [other_gvar, other_prim_func] : functions) {
cg.DeclareFunction(other_gvar, other_prim_func);
}
cg.AddFunction(gvar, prim_func);
cg.AddFunction(kv.first, f);

std::string fsource = cg.Finish();
source_maker << fsource << "\n";
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).operator std::string();
}
smap[cg.GetFunctionName(gvar)] = fsource;
smap[func_name] = fsource;
}

return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
Expand Down
3 changes: 1 addition & 2 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ class CodeGenMetal final : public CodeGenC {
explicit CodeGenMetal(Target target);
// override print thread tag.
void PrintArgUnionDecl();
void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
std::ostream& os) override;
void AddFunction(const GlobalVar& gvar, const PrimFunc& func) final;
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
Expand Down
23 changes: 23 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,28 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")):
np.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5)


# Regression test for Metal codegen of PrimFuncs with trailing POD parameters:
# builds such a function for the "metal" target and checks that the generated
# source declares the kernel-argument struct exactly once.
@tvm.testing.requires_metal(support_required="compile-only")
def test_func_with_trailing_pod_params():
# Imported inside the test (hence the pylint suppression); only the compile
# callback registered below uses it.
from tvm.contrib import xcode # pylint: disable=import-outside-toplevel

# Kernel under test: two 16-element float32 buffers followed by one scalar
# float32 parameter `x` (the trailing POD parameter). Each of the 16 threads
# bound to threadIdx.x writes B[vi] = A[vi] + x.
@T.prim_func
def func(A: T.Buffer((16), "float32"), B: T.Buffer((16), "float32"), x: T.float32):
for i in T.thread_binding(16, thread="threadIdx.x"):
with T.block("block"):
vi = T.axis.spatial(16, i)
B[vi] = A[vi] + x

# Register a callback under the name "tvm_callback_metal_compile" that forwards
# the generated source to xcode.compile_metal; `target` is accepted but unused.
# NOTE(review): the registration is not undone at the end of the test, so it
# presumably stays active for later tests in the same process -- confirm.
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src, target):
return xcode.compile_metal(src)

mod = tvm.IRModule({"main": func})

f = tvm.build(mod, target="metal")
# Fetch the source of the first imported module and count the declarations of
# the argument struct; more than one occurrence means the struct was emitted
# repeatedly.
# NOTE(review): the struct name assumes the device kernel symbol is
# "func_kernel" -- confirm against the lowering pipeline.
src: str = f.imported_modules[0].get_source()
occurrences = src.count("struct func_kernel_args_t")
assert occurrences == 1, occurrences


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit b95db83

Please sign in to comment.