From 23200eb621b241416fb7bf75e8c9430e1002d047 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Sat, 10 Dec 2022 23:52:48 -0500 Subject: [PATCH] convert algorithms to SCC These places in the code can either be more efficient O(1) or more correct using something more similar to the published SCC algorithm by Tarjan for strongly connected components. --- src/gf.c | 1 + src/jitlayers.cpp | 101 ++++++++++++------------ src/staticdata_utils.c | 172 ++++++++++++++--------------------------- 3 files changed, 111 insertions(+), 163 deletions(-) diff --git a/src/gf.c b/src/gf.c index 6d705f15a482c..537677784c477 100644 --- a/src/gf.c +++ b/src/gf.c @@ -3377,6 +3377,7 @@ static jl_value_t *ml_matches(jl_methtable_t *mt, } } // then we'll merge those numbers to assign each item in the group the same number + // (similar to Kosaraju's SCC algorithm?) uint32_t groupid = 0; uint32_t grouphi = 0; for (i = 0; i < len; i++) { diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index da5e8c58fdecd..2d8da6e755cb4 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -136,7 +136,11 @@ void jl_dump_llvm_opt_impl(void *s) **jl_ExecutionEngine->get_dump_llvm_opt_stream() = (JL_STREAM*)s; } -static void jl_add_to_ee(orc::ThreadSafeModule &M, StringMap &NewExports); +static int jl_add_to_ee( + orc::ThreadSafeModule &M, + const StringMap &NewExports, + DenseMap &Queued, + std::vector &Stack); static void jl_decorate_module(Module &M); static uint64_t getAddressForFunction(StringRef fname); @@ -228,10 +232,13 @@ static jl_callptr_t _jl_compile_codeinst( } } } + DenseMap Queued; + std::vector Stack; for (auto &def : emitted) { // Add the results to the execution engine now orc::ThreadSafeModule &M = std::get<0>(def.second); - jl_add_to_ee(M, NewExports); + jl_add_to_ee(M, NewExports, Queued, Stack); + assert(Queued.empty() && Stack.empty() && !M); } ++CompiledCodeinsts; MaxWorkqueueSize.updateMax(emitted.size()); @@ -1700,76 +1707,72 @@ static void jl_decorate_module(Module &M) { #endif } +// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable static int jl_add_to_ee( orc::ThreadSafeModule &M, - StringMap &NewExports, + const StringMap &NewExports, DenseMap &Queued, - std::vector> &ToMerge, - int depth) + std::vector &Stack) { - // DAG-sort (post-dominator) the compile to compute the minimum - // merge-module sets for linkage + // First check if the TSM is empty (already compiled) if (!M) return 0; - // First check and record if it's on the stack somewhere + // Next check and record if it is on the stack somewhere { - auto &Cycle = Queued[&M]; - if (Cycle) - return Cycle; - ToMerge.push_back({}); - Cycle = depth; + auto &Id = Queued[&M]; + if (Id) + return Id; + Stack.push_back(&M); + Id = Stack.size(); } + // Finally work out the SCC + int depth = Stack.size(); int MergeUp = depth; - // Compute the cycle-id + std::vector Children; M.withModuleDo([&](Module &m) { for (auto &F : m.global_objects()) { if (F.isDeclaration() && F.getLinkage() == GlobalValue::ExternalLinkage) { auto Callee = NewExports.find(F.getName()); if (Callee != NewExports.end()) { - auto &CM = Callee->second; - int Down = jl_add_to_ee(*CM, NewExports, Queued, ToMerge, depth + 1); - assert(Down <= depth); - if (Down && Down < MergeUp) - MergeUp = Down; + auto *CM = Callee->second; + if (*CM && CM != &M) { + auto Down = Queued.find(CM); + if (Down != Queued.end()) + MergeUp = std::min(MergeUp, Down->second); + else + Children.push_back(CM); + } } } } }); - if (MergeUp == depth) { - // Not in a cycle (or at the top of it) - Queued.erase(&M); - for (auto &CM : ToMerge.at(depth - 1)) { - assert(Queued.find(CM)->second == depth); - Queued.erase(CM); - jl_merge_module(M, std::move(*CM)); - } - jl_ExecutionEngine->addModule(std::move(M)); - MergeUp = 0; + assert(MergeUp > 0); + for (auto *CM : Children) { + int Down = jl_add_to_ee(*CM, NewExports, Queued, Stack); + assert(Down <= (int)Stack.size()); + if (Down) + MergeUp = std::min(MergeUp, Down); } - else { - // Add our frame(s) to the top of the cycle - Queued[&M] = MergeUp; - auto &Top = ToMerge.at(MergeUp - 1); - Top.push_back(&M); - for (auto &CM : ToMerge.at(depth - 1)) { - assert(Queued.find(CM)->second == depth); - Queued[CM] = MergeUp; - Top.push_back(CM); + if (MergeUp < depth) + return MergeUp; + while (1) { + // Not in a cycle (or at the top of it) + // remove SCC state and merge every CM from the cycle into M + orc::ThreadSafeModule *CM = Stack.back(); + auto it = Queued.find(CM); + assert(it->second == (int)Stack.size()); + Queued.erase(it); + Stack.pop_back(); + if ((int)Stack.size() < depth) { + assert(&M == CM); + break; } + jl_merge_module(M, std::move(*CM)); } - ToMerge.pop_back(); - return MergeUp; -} - -static void jl_add_to_ee(orc::ThreadSafeModule &M, StringMap &NewExports) -{ - DenseMap Queued; - std::vector> ToMerge; - jl_add_to_ee(M, NewExports, Queued, ToMerge, 1); - assert(!M); + jl_ExecutionEngine->addModule(std::move(M)); + return 0; } - static uint64_t getAddressForFunction(StringRef fname) { auto addr = jl_ExecutionEngine->getFunctionAddress(fname); diff --git a/src/staticdata_utils.c b/src/staticdata_utils.c index 3d02dddbd5a70..6f673563ff3ad 100644 --- a/src/staticdata_utils.c +++ b/src/staticdata_utils.c @@ -158,31 +158,10 @@ static int type_in_worklist(jl_value_t *v) JL_NOTSAFEPOINT return 0; } -static void mark_backedges_in_worklist(jl_method_instance_t *mi, htable_t *visited, int found) -{ - int oldfound = (char*)ptrhash_get(visited, mi) - (char*)HT_NOTFOUND; - if (oldfound < 3) - return; // not in-progress - ptrhash_put(visited, mi, (void*)((char*)HT_NOTFOUND + 1 + found)); -#ifndef NDEBUG - jl_module_t *mod = mi->def.module; - if (jl_is_method(mod)) - mod = ((jl_method_t*)mod)->module; - assert(jl_is_module(mod)); - assert(!mi->precompiled && jl_object_in_image((jl_value_t*)mod)); - assert(mi->backedges); -#endif - size_t i = 0, n = jl_array_len(mi->backedges); - while (i < n) { - jl_method_instance_t *be; - i = get_next_edge(mi->backedges, i, NULL, &be); - mark_backedges_in_worklist(be, visited, found); - } -} - // When we infer external method instances, ensure they link back to the -// package. Otherwise they might be, e.g., for external macros -static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, int depth) +// package. Otherwise they might be, e.g., for external macros. +// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable +static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack) { jl_module_t *mod = mi->def.module; if (jl_is_method(mod)) @@ -202,14 +181,17 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, int found = (char*)*bp - (char*)HT_NOTFOUND; if (found) return found - 1; + arraylist_push(stack, (void*)mi); + int depth = stack->len; *bp = (void*)((char*)HT_NOTFOUND + 3 + depth); // preliminarily mark as in-progress size_t i = 0, n = jl_array_len(mi->backedges); int cycle = 0; while (i < n) { jl_method_instance_t *be; i = get_next_edge(mi->backedges, i, NULL, &be); - int child_found = has_backedge_to_worklist(be, visited, depth + 1); + int child_found = has_backedge_to_worklist(be, visited, stack); if (child_found == 1) { + // found what we were looking for, so terminate early found = 1; break; } @@ -221,22 +203,15 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, } if (!found && cycle && cycle != depth) return cycle + 2; - bp = ptrhash_bp(visited, mi); // re-acquire since rehashing might change the location - *bp = (void*)((char*)HT_NOTFOUND + 1 + found); - if (cycle) { - // If we are the top of the current cycle, now mark all other parts of - // our cycle by re-walking the backedges graph and marking all WIP - // items as found. - // Be careful to only re-walk as far as we had originally scanned above. - // Or if we found a backedge, also mark all of the other parts of the - // cycle as also having an backedge. - n = i; - i = 0; - while (i < n) { - jl_method_instance_t *be; - i = get_next_edge(mi->backedges, i, NULL, &be); - mark_backedges_in_worklist(be, visited, found); - } + // If we are the top of the current cycle, now mark all other parts of + // our cycle with what we found. + // Or if we found a backedge, also mark all of the other parts of the + // cycle as also having an backedge. + while (stack->len >= depth) { + void *mi = arraylist_pop(stack); + bp = ptrhash_bp(visited, mi); + assert((char*)*bp - (char*)HT_NOTFOUND == 4 + stack->len); + *bp = (void*)((char*)HT_NOTFOUND + 1 + found); } return found; } @@ -251,9 +226,11 @@ static jl_array_t *queue_external_cis(jl_array_t *list) return NULL; size_t i; htable_t visited; + arraylist_t stack; assert(jl_is_array(list)); size_t n0 = jl_array_len(list); htable_new(&visited, n0); + arraylist_new(&stack, 0); jl_array_t *new_specializations = jl_alloc_vec_any(0); JL_GC_PUSH1(&new_specializations); for (i = 0; i < n0; i++) { @@ -264,8 +241,9 @@ static jl_array_t *queue_external_cis(jl_array_t *list) if (jl_is_method(m)) { if (jl_object_in_image((jl_value_t*)m->module)) { if (ptrhash_get(&external_mis, mi) == HT_NOTFOUND) { - int found = has_backedge_to_worklist(mi, &visited, 1); + int found = has_backedge_to_worklist(mi, &visited, &stack); assert(found == 0 || found == 1); + assert(stack.len == 0); if (found == 1) { ptrhash_put(&external_mis, mi, mi); jl_array_ptr_1d_push(new_specializations, (jl_value_t*)ci); @@ -275,6 +253,7 @@ static jl_array_t *queue_external_cis(jl_array_t *list) } } htable_free(&visited); + arraylist_free(&stack); JL_GC_POP(); return new_specializations; } @@ -970,56 +949,23 @@ static void jl_verify_methods(jl_array_t *edges, jl_array_t *valids, htable_t *v } -// Propagate the result of cycle-resolution to all edges (recursively) -static int mark_edges_in_worklist(jl_array_t *edges, int idx, jl_method_instance_t *cycle, htable_t *visited, int found) -{ - jl_method_instance_t *caller = (jl_method_instance_t*)jl_array_ptr_ref(edges, idx * 2); - int oldfound = (char*)ptrhash_get(visited, caller) - (char*)HT_NOTFOUND; - if (oldfound < 3) - return 0; // not in-progress - if (!found) { - ptrhash_remove(visited, (void*)caller); - } - else { - ptrhash_put(visited, (void*)caller, (void*)((char*)HT_NOTFOUND + 1 + found)); - } - jl_array_t *callee_ids = (jl_array_t*)jl_array_ptr_ref(edges, idx * 2 + 1); - assert(jl_typeis((jl_value_t*)callee_ids, jl_array_int32_type)); - int32_t *idxs = (int32_t*)jl_array_data(callee_ids); - size_t i, badidx = 0, n = jl_array_len(callee_ids); - for (i = idxs[0] + 1; i < n; i++) { - if (mark_edges_in_worklist(edges, idxs[i], cycle, visited, found) && badidx == 0) - badidx = i - idxs[0]; - } - if (_jl_debug_method_invalidation) { - jl_value_t *loctag = NULL; - JL_GC_PUSH1(&loctag); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)caller); - loctag = jl_cstr_to_string("verify_methods"); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag); - jl_method_instance_t *callee = cycle; - if (badidx--) - callee = (jl_method_instance_t*)jl_array_ptr_ref(edges, 2 * badidx); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)callee); - JL_GC_POP(); - } - return 1; -} - - // Visit the entire call graph, starting from edges[idx] to determine if that method is valid -static int jl_verify_graph_edge(jl_array_t *edges, int idx, htable_t *visited, int depth) +// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable +static int jl_verify_graph_edge(jl_array_t *edges, int idx, htable_t *visited, arraylist_t *stack) { jl_method_instance_t *caller = (jl_method_instance_t*)jl_array_ptr_ref(edges, idx * 2); assert(jl_is_method_instance(caller) && jl_is_method(caller->def.method)); int found = (char*)ptrhash_get(visited, (void*)caller) - (char*)HT_NOTFOUND; if (found == 0) - return 1; // valid + return 1; // NOTFOUND == valid if (found == 1) return 0; // invalid if (found != 2) return found - 1; // depth found = 0; + jl_value_t *cause = NULL; + arraylist_push(stack, (void*)caller); + int depth = stack->len; ptrhash_put(visited, (void*)caller, (void*)((char*)HT_NOTFOUND + 3 + depth)); // change 2 to in-progress at depth jl_array_t *callee_ids = (jl_array_t*)jl_array_ptr_ref(edges, idx * 2 + 1); assert(jl_typeis((jl_value_t*)callee_ids, jl_array_int32_type)); @@ -1028,18 +974,11 @@ static int jl_verify_graph_edge(jl_array_t *edges, int idx, htable_t *visited, i size_t i, n = jl_array_len(callee_ids); for (i = idxs[0] + 1; i < n; i++) { int32_t idx = idxs[i]; - int child_found = jl_verify_graph_edge(edges, idx, visited, depth + 1); + int child_found = jl_verify_graph_edge(edges, idx, visited, stack); if (child_found == 0) { + // found what we were looking for, so terminate early found = 1; - if (_jl_debug_method_invalidation) { - jl_value_t *loctag = NULL; - JL_GC_PUSH1(&loctag); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)caller); - loctag = jl_cstr_to_string("verify_methods"); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag); - jl_array_ptr_1d_push(_jl_debug_method_invalidation, jl_array_ptr_ref(edges, idx * 2)); - JL_GC_POP(); - } + cause = jl_array_ptr_ref(edges, idx * 2); break; } else if (child_found >= 2 && child_found - 2 < cycle) { @@ -1048,24 +987,27 @@ static int jl_verify_graph_edge(jl_array_t *edges, int idx, htable_t *visited, i assert(cycle); } } - if (!found) { - if (cycle && cycle != depth) - return cycle + 2; - ptrhash_remove(visited, (void*)caller); - } - else { // found invalid - ptrhash_put(visited, (void*)caller, (void*)((char*)HT_NOTFOUND + 1 + found)); - } - if (cycle) { - // If we are the top of the current cycle, now mark all other parts of - // our cycle by re-walking the backedges graph and marking all WIP - // items as found. - // Be careful to only re-walk as far as we had originally scanned above. - // Or if we found a backedge, also mark all of the other parts of the - // cycle as also having an backedge. - n = i; - for (i = idxs[0] + 1; i < n; i++) { - mark_edges_in_worklist(edges, idxs[i], caller, visited, found); + if (!found && cycle && cycle != depth) + return cycle + 2; + // If we are the top of the current cycle, now mark all other parts of + // our cycle with what we found. + // Or if we found a backedge, also mark all of the other parts of the + // cycle as also having an backedge. + while (stack->len >= depth) { + void *mi = arraylist_pop(stack); + assert((char*)ptrhash_get(visited, mi) - (char*)HT_NOTFOUND == 4 + stack->len); + if (found) + ptrhash_put(visited, mi, (void*)((char*)HT_NOTFOUND + 1 + found)); + else + ptrhash_remove(visited, mi); // assign as NOTFOUND in table + if (_jl_debug_method_invalidation && found) { + jl_value_t *loctag = NULL; + JL_GC_PUSH1(&loctag); + jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)mi); + loctag = jl_cstr_to_string("verify_methods"); + jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag); + jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)cause); + JL_GC_POP(); } } return found ? 0 : 1; @@ -1074,13 +1016,15 @@ static int jl_verify_graph_edge(jl_array_t *edges, int idx, htable_t *visited, i // Visit all entries in edges, verify if they are valid static jl_array_t *jl_verify_graph(jl_array_t *edges, htable_t *visited) { + arraylist_t stack; + arraylist_new(&stack, 0); size_t i, n = jl_array_len(edges) / 2; jl_array_t *valids = jl_alloc_array_1d(jl_array_uint8_type, n); JL_GC_PUSH1(&valids); int8_t *valids_data = (int8_t*)jl_array_data(valids); - for (i = 0; i < n; i++) { - valids_data[i] = jl_verify_graph_edge(edges, i, visited, 1); - } + for (i = 0; i < n; i++) + valids_data[i] = jl_verify_graph_edge(edges, i, visited, &stack); + arraylist_free(&stack); JL_GC_POP(); return valids; } @@ -1096,8 +1040,8 @@ static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_targets, jl_a JL_GC_PUSH1(&valids); htable_t visited; htable_new(&visited, 0); - jl_verify_methods(edges, valids, &visited); - valids = jl_verify_graph(edges, &visited); + jl_verify_methods(edges, valids, &visited); // consumes valids, creates visited + valids = jl_verify_graph(edges, &visited); // consumes visited, creates valids size_t i, l = jl_array_len(edges) / 2; // next build a map from external MethodInstances to their CodeInstance for insertion