Skip to content

Commit

Permalink
Merge opaque closure modules with the rest of the workqueue (#50724)
Browse files Browse the repository at this point in the history
This sticks the compiled opaque closure module into the
`compiled_functions` list of modules that we have compiled for the
particular `jl_codegen_params_t`. We probably should manage that vector
in codegen_params, since it lets us see if a particular codeinst has
already been compiled but not yet emitted.

(cherry picked from commit 441fcb1)
  • Loading branch information
pchintalapudi committed Aug 10, 2023
1 parent 8ad72d3 commit e549d74
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 100 deletions.
48 changes: 32 additions & 16 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
jl_native_code_desc_t *data = new jl_native_code_desc_t;
CompilationPolicy policy = (CompilationPolicy) _policy;
bool imaging = imaging_default() || _imaging_mode == 1;
jl_workqueue_t emitted;
jl_method_instance_t *mi = NULL;
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
Expand Down Expand Up @@ -334,7 +333,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
// find and prepare the source code to compile
jl_code_instance_t *codeinst = NULL;
jl_ci_cache_lookup(*cgparams, mi, params.world, &codeinst, &src);
if (src && !emitted.count(codeinst)) {
if (src && !params.compiled_functions.count(codeinst)) {
// now add it to our compilation results
JL_GC_PROMISE_ROOTED(codeinst->rettype);
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
Expand All @@ -343,13 +342,13 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
Triple(clone.getModuleUnlocked()->getTargetTriple()));
jl_llvm_functions_t decls = jl_emit_code(result_m, mi, src, codeinst->rettype, params);
if (result_m)
emitted[codeinst] = {std::move(result_m), std::move(decls)};
params.compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
}
}
}

// finally, make sure all referenced methods also get compiled or fixed up
jl_compile_workqueue(emitted, *clone.getModuleUnlocked(), params, policy);
jl_compile_workqueue(params, *clone.getModuleUnlocked(), policy);
}
JL_UNLOCK(&jl_codegen_lock); // Might GC
JL_GC_POP();
Expand All @@ -368,7 +367,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
data->jl_value_to_llvm[idx] = global.first;
idx++;
}
CreateNativeMethods += emitted.size();
CreateNativeMethods += params.compiled_functions.size();

size_t offset = gvars.size();
data->jl_external_to_llvm.resize(params.external_fns.size());
Expand All @@ -390,17 +389,34 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm

// clones the contents of the module `m` to the shadow_output collector
// while examining and recording what kind of function pointer we have
Linker L(*clone.getModuleUnlocked());
for (auto &def : emitted) {
jl_merge_module(clone, std::move(std::get<0>(def.second)));
jl_code_instance_t *this_code = def.first;
jl_llvm_functions_t decls = std::get<1>(def.second);
StringRef func = decls.functionObject;
StringRef cfunc = decls.specFunctionObject;
uint32_t func_id = 0;
uint32_t cfunc_id = 0;
if (func == "jl_fptr_args") {
func_id = -1;
{
JL_TIMING(NATIVE_AOT, NATIVE_Merge);
Linker L(*clone.getModuleUnlocked());
for (auto &def : params.compiled_functions) {
jl_merge_module(clone, std::move(std::get<0>(def.second)));
jl_code_instance_t *this_code = def.first;
jl_llvm_functions_t decls = std::get<1>(def.second);
StringRef func = decls.functionObject;
StringRef cfunc = decls.specFunctionObject;
uint32_t func_id = 0;
uint32_t cfunc_id = 0;
if (func == "jl_fptr_args") {
func_id = -1;
}
else if (func == "jl_fptr_sparam") {
func_id = -2;
}
else {
//Safe b/c context is locked by params
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(func)));
func_id = data->jl_sysimg_fvars.size();
}
if (!cfunc.empty()) {
//Safe b/c context is locked by params
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(cfunc)));
cfunc_id = data->jl_sysimg_fvars.size();
}
data->jl_fvar_map[this_code] = std::make_tuple(func_id, cfunc_id);
}
else if (func == "jl_fptr_sparam") {
func_id = -2;
Expand Down
112 changes: 48 additions & 64 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,6 @@ class jl_codectx_t {
std::vector<std::tuple<jl_cgval_t, BasicBlock *, AllocaInst *, PHINode *, jl_value_t *>> PhiNodes;
std::vector<bool> ssavalue_assigned;
std::vector<int> ssavalue_usecount;
std::vector<orc::ThreadSafeModule> oc_modules;
jl_module_t *module = NULL;
jl_typecache_t type_cache;
jl_tbaacache_t tbaa_cache;
Expand Down Expand Up @@ -4451,7 +4450,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
// Check if we already queued this up
auto it = ctx.call_targets.find(codeinst);
if (need_to_emit && it != ctx.call_targets.end()) {
protoname = std::get<2>(it->second)->getName();
protoname = it->second.decl->getName();
need_to_emit = cache_valid = false;
}

Expand Down Expand Up @@ -4495,7 +4494,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
handled = true;
if (need_to_emit) {
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
ctx.call_targets[codeinst] = std::make_tuple(cc, return_roots, trampoline_decl, specsig);
ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig};
}
}
}
Expand Down Expand Up @@ -5353,8 +5352,7 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
{
jl_svec_t *sig_args = NULL;
jl_value_t *sigtype = NULL;
jl_code_info_t *ir = NULL;
JL_GC_PUSH3(&sig_args, &sigtype, &ir);
JL_GC_PUSH2(&sig_args, &sigtype);

size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
sig_args = jl_alloc_svec_uninit(nsig);
Expand All @@ -5376,16 +5374,25 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
JL_GC_POP();
return std::make_pair((Function*)NULL, (Function*)NULL);
}
++EmittedOpaqueClosureFunctions;

ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
auto it = ctx.emission_context.compiled_functions.find(ci);

// TODO: Emit this inline and outline it late using LLVM's coroutine support.
orc::ThreadSafeModule closure_m = jl_create_ts_module(
name_from_method_instance(mi), ctx.emission_context.tsctx,
ctx.emission_context.imaging,
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
if (it == ctx.emission_context.compiled_functions.end()) {
++EmittedOpaqueClosureFunctions;
jl_code_info_t *ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
JL_GC_PUSH1(&ir);
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
orc::ThreadSafeModule closure_m = jl_create_ts_module(
name_from_method_instance(mi), ctx.emission_context.tsctx,
ctx.emission_context.imaging,
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
JL_GC_POP();
it = ctx.emission_context.compiled_functions.insert(std::make_pair(ci, std::make_pair(std::move(closure_m), std::move(closure_decls)))).first;
}

auto &closure_m = it->second.first;
auto &closure_decls = it->second.second;

assert(closure_decls.functionObject != "jl_fptr_sparam");
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";
Expand Down Expand Up @@ -5416,7 +5423,6 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
specF = cast<Function>(returninfo.decl.getCallee());
}
}
ctx.oc_modules.push_back(std::move(closure_m));
JL_GC_POP();
return std::make_pair(F, specF);
}
Expand Down Expand Up @@ -5699,7 +5705,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
if (jl_is_concrete_type(env_t)) {
jl_tupletype_t *argt_typ = (jl_tupletype_t*)argt.constant;
Function *F, *specF;
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_datatype_t*)env_t, argt_typ, ub.constant);
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_tupletype_t*)env_t, argt_typ, ub.constant);
if (F) {
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
Expand All @@ -5709,7 +5715,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
if (specF)
fptr = mark_julia_type(ctx, specF, false, jl_voidpointer_type);
else
fptr = mark_julia_type(ctx, (llvm::Value*)Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);
fptr = mark_julia_type(ctx, Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);

// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
jl_cgval_t env = emit_new_struct(ctx, env_t, nargs-4, &argv.data()[4]);
Expand Down Expand Up @@ -8675,19 +8681,6 @@ static jl_llvm_functions_t
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
}

// link in opaque closure modules
for (auto &TSMod : ctx.oc_modules) {
SmallVector<std::string, 1> Exports;
TSMod.withModuleDo([&](Module &Mod) {
for (const auto &F: Mod.functions())
if (!F.isDeclaration())
Exports.push_back(F.getName().str());
});
jl_merge_module(TSM, std::move(TSMod));
for (auto FN: Exports)
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
}

JL_GC_POP();
return declarations;
}
Expand Down Expand Up @@ -8849,22 +8842,18 @@ jl_llvm_functions_t jl_emit_codeinst(


void jl_compile_workqueue(
jl_workqueue_t &emitted,
jl_codegen_params_t &params,
Module &original,
jl_codegen_params_t &params, CompilationPolicy policy)
CompilationPolicy policy)
{
JL_TIMING(CODEGEN, CODEGEN_Workqueue);
jl_code_info_t *src = NULL;
JL_GC_PUSH1(&src);
while (!params.workqueue.empty()) {
jl_code_instance_t *codeinst;
Function *protodecl;
jl_returninfo_t::CallingConv proto_cc;
bool proto_specsig;
unsigned proto_return_roots;
auto it = params.workqueue.back();
codeinst = it.first;
std::tie(proto_cc, proto_return_roots, protodecl, proto_specsig) = it.second;
auto proto = it.second;
params.workqueue.pop_back();
// try to emit code for this item from the workqueue
assert(codeinst->min_world <= params.world && codeinst->max_world >= params.world &&
Expand Down Expand Up @@ -8892,12 +8881,8 @@ void jl_compile_workqueue(
}
}
else {
auto &result = emitted[codeinst];
jl_llvm_functions_t *decls = NULL;
if (std::get<0>(result)) {
decls = &std::get<1>(result);
}
else {
auto it = params.compiled_functions.find(codeinst);
if (it == params.compiled_functions.end()) {
// Reinfer the function. The JIT came along and removed the inferred
// method body. See #34993
if (policy != CompilationPolicy::Default &&
Expand All @@ -8908,47 +8893,46 @@ void jl_compile_workqueue(
jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.imaging,
original.getDataLayout(), Triple(original.getTargetTriple()));
result.second = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
result.first = std::move(result_m);
auto decls = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
if (result_m)
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
}
else {
orc::ThreadSafeModule result_m =
jl_create_ts_module(name_from_method_instance(codeinst->def),
params.tsctx, params.imaging,
original.getDataLayout(), Triple(original.getTargetTriple()));
result.second = jl_emit_codeinst(result_m, codeinst, NULL, params);
result.first = std::move(result_m);
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
if (result_m)
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
}
if (std::get<0>(result))
decls = &std::get<1>(result);
else
emitted.erase(codeinst); // undo the insert above
}
if (decls) {
if (decls->functionObject == "jl_fptr_args") {
preal_decl = decls->specFunctionObject;
if (it != params.compiled_functions.end()) {
auto &decls = it->second.second;
if (decls.functionObject == "jl_fptr_args") {
preal_decl = decls.specFunctionObject;
}
else if (decls->functionObject != "jl_fptr_sparam") {
preal_decl = decls->specFunctionObject;
else if (decls.functionObject != "jl_fptr_sparam") {
preal_decl = decls.specFunctionObject;
preal_specsig = true;
}
}
}
// patch up the prototype we emitted earlier
Module *mod = protodecl->getParent();
assert(protodecl->isDeclaration());
if (proto_specsig) {
Module *mod = proto.decl->getParent();
assert(proto.decl->isDeclaration());
if (proto.specsig) {
// expected specsig
if (!preal_specsig) {
// emit specsig-to-(jl)invoke conversion
Function *preal = emit_tojlinvoke(codeinst, mod, params);
protodecl->setLinkage(GlobalVariable::InternalLinkage);
proto.decl->setLinkage(GlobalVariable::InternalLinkage);
//protodecl->setAlwaysInline();
jl_init_function(protodecl, params.TargetTriple);
jl_init_function(proto.decl, params.TargetTriple);
size_t nrealargs = jl_nparams(codeinst->def->specTypes); // number of actual arguments being passed
// TODO: maybe this can be cached in codeinst->specfptr?
emit_cfunc_invalidate(protodecl, proto_cc, proto_return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
emit_cfunc_invalidate(proto.decl, proto.cc, proto.return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
preal_decl = ""; // no need to fixup the name
}
else {
Expand All @@ -8965,11 +8949,11 @@ void jl_compile_workqueue(
if (!preal_decl.empty()) {
// merge and/or rename this prototype to the real function
if (Value *specfun = mod->getNamedValue(preal_decl)) {
if (protodecl != specfun)
protodecl->replaceAllUsesWith(specfun);
if (proto.decl != specfun)
proto.decl->replaceAllUsesWith(specfun);
}
else {
protodecl->setName(preal_decl);
proto.decl->setName(preal_decl);
}
}
}
Expand Down
Loading

0 comments on commit e549d74

Please sign in to comment.