Skip to content

Commit

Permalink
Port POSet to Java (#4103)
Browse files Browse the repository at this point in the history
Part of #4030.

Port the `POSet` class from Scala to Java. 

This is mostly a straightforward conversion from Scala functional idioms
(`map`, `filter`, `reduce`, etc.) to the corresponding Java Stream
methods. To port `lazy val`s, we also implement a `Lazy<T>` wrapper
which caches the result of a `Supplier<T>`.

The only remaining reference to Scala here is a single constructor which
takes a `scala.collection.Set<Tuple2<T, T>>` and internally converts it
to a `java.util.Set<Pair<T, T>>`. This can be removed once we more
pervasively switch away from Scala collection types everywhere in the
codebase.

---------

Co-authored-by: Bruce Collie <brucecollie82@gmail.com>
  • Loading branch information
Scott-Guest and Baltoli committed Mar 18, 2024
1 parent 4f1b4e2 commit 2dbd10c
Show file tree
Hide file tree
Showing 16 changed files with 463 additions and 293 deletions.
@@ -1,2 +1,2 @@
[Error] Compiler: Had 1 parsing errors.
[Error] Compiler: Illegal circular relation: Exp < ExpList < Exp
[Error] Compiler: Illegal circular relation: ExpList < Exp < ExpList
37 changes: 11 additions & 26 deletions kernel/src/main/java/org/kframework/backend/kore/ModuleToKORE.java
Expand Up @@ -237,12 +237,8 @@ public void convert(

semantics.append("\n// symbols\n");
Set<Production> overloads = new HashSet<>();
for (Production lesser : iterable(module.overloads().elements())) {
for (Production greater :
iterable(
module.overloads().relations().get(lesser).getOrElse(Collections::<Production>Set))) {
overloads.add(greater);
}
for (Production lesser : module.overloads().elements()) {
overloads.addAll(module.overloads().relations().getOrDefault(lesser, Set.of()));
}
translateSymbols(attributes, functionRules, overloads, semantics);

Expand Down Expand Up @@ -270,14 +266,8 @@ public void convert(
}
}

for (Production lesser : iterable(module.overloads().elements())) {
for (Production greater :
iterable(
module
.overloads()
.relations()
.get(lesser)
.getOrElse(() -> Collections.<Production>Set()))) {
for (Production lesser : module.overloads().elements()) {
for (Production greater : module.overloads().relations().getOrDefault(lesser, Set.of())) {
genOverloadedAxiom(lesser, greater, syntax);
}
}
Expand Down Expand Up @@ -321,14 +311,8 @@ public void convert(
genNoJunkAxiom(sort, semantics);
}

for (Production lesser : iterable(module.overloads().elements())) {
for (Production greater :
iterable(
module
.overloads()
.relations()
.get(lesser)
.getOrElse(() -> Collections.<Production>Set()))) {
for (Production lesser : module.overloads().elements()) {
for (Production greater : module.overloads().relations().getOrDefault(lesser, Set.of())) {
genOverloadedAxiom(lesser, greater, semantics);
}
}
Expand Down Expand Up @@ -1741,10 +1725,11 @@ private Att addKoreAttributes(
att = att.add(Att.TERMINALS(), sb.toString());
if (prod.klabel().isDefined()) {
List<K> lessThanK = new ArrayList<>();
Option<scala.collection.Set<Tag>> lessThan =
module.priorities().relations().get(Tag(prod.klabel().get().name()));
if (lessThan.isDefined()) {
for (Tag t : iterable(lessThan.get())) {
Optional<Set<Tag>> lessThan =
Optional.ofNullable(
module.priorities().relations().get(Tag(prod.klabel().get().name())));
if (lessThan.isPresent()) {
for (Tag t : lessThan.get()) {
if (ConstructorChecks.isBuiltinLabel(KLabel(t.name()))) {
continue;
}
Expand Down
Expand Up @@ -28,7 +28,7 @@ public ResolveOverloadedTerminators(POSet<Production> overloads) {
public Either<Set<KEMException>, Term> apply(TermCons tc) {
if (overloads.elements().contains(tc.production()) && tc.items().isEmpty()) {
Set<Production> candidates =
stream(overloads.elements())
streamIter(overloads.elements())
.filter(
p ->
p.klabel().isDefined()
Expand Down
Expand Up @@ -19,6 +19,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.kframework.Collections;
import org.kframework.POSet;
import org.kframework.attributes.Att;
Expand All @@ -39,9 +40,7 @@
import org.kframework.parser.inner.RuleGrammarGenerator;
import org.kframework.utils.OS;
import org.kframework.utils.errorsystem.KEMException;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Set;

/**
* Class to manage communication with z3 for the purposes of type inference. This class is driven by
Expand Down Expand Up @@ -166,27 +165,28 @@ private void makeSubsorts(Module mod, String name, POSet<Sort> relations) {
Map<SortHead, Integer> ordinals = new HashMap<>();
int i = 0;

for (Sort s : iterable(relations.sortedElements())) {
for (Sort s : relations.sortedElements()) {
if (!isRealSort(s.head())) {
continue;
}
ordinals.put(s.head(), i++);
}
// provide fixed interpretation of subsort relation
println("(define-fun " + name + " ((s1 Sort) (s2 Sort)) Bool (or");
for (Tuple2<Sort, Set<Sort>> relation :
stream(relations.relations())
.sorted(Comparator.comparing(t -> -ordinals.getOrDefault(t._1().head(), 0)))
for (Pair<Sort, java.util.Set<Sort>> relation :
relations.relations().entrySet().stream()
.map(t -> Pair.of(t.getKey(), t.getValue()))
.sorted(Comparator.comparing(t -> -ordinals.getOrDefault(t.getLeft().head(), 0)))
.toList()) {
if (!isRealSort(relation._1().head())) {
if (!isRealSort(relation.getLeft().head())) {
continue;
}
for (Sort s2 : iterable(relation._2())) {
for (Sort s2 : relation.getRight()) {
if (!isRealSort(s2.head())) {
continue;
}
print(" (and (= s1 ");
printSort(relation._1());
printSort(relation.getLeft());
print(") (= s2 ");
printSort(s2);
println("))");
Expand Down
Expand Up @@ -53,12 +53,7 @@ private static void computeSide(
MutableInt nextOrdinal) {
NonTerminal nt = (NonTerminal) items.get(idx);
Tag parent = new Tag(prod.klabel().get().name());
Set<Tag> prods = new HashSet<>();
for (Tag child :
iterable(
module.priorities().relations().get(parent).getOrElse(() -> Collections.<Tag>Set()))) {
prods.add(child);
}
Set<Tag> prods = new HashSet<>(module.priorities().relations().getOrDefault(parent, Set.of()));
for (Tuple2<Tag, Tag> entry : iterable(assoc)) {
if (entry._1().equals(parent)) {
prods.add(entry._2());
Expand Down Expand Up @@ -322,7 +317,7 @@ private static void appendOverloadCondition(
.append(nts.get(i))
.append(".nterm->symbol, \"inj{\", 4) == 0 && (false");
Sort greaterSort = lesser.nonterminals().apply(i).sort();
for (Sort lesserSort : iterable(module.subsorts().elements())) {
for (Sort lesserSort : module.subsorts().elements()) {
if (module.subsorts().lessThanEq(lesserSort, greaterSort)) {
bison.append(" || strcmp($").append(nts.get(i)).append(".nterm->children[0]->sort, \"");
encodeKore(lesserSort, bison);
Expand All @@ -341,7 +336,7 @@ private static void appendOverloadChecks(
Production greater,
List<Integer> nts,
boolean hasLocation) {
for (Production lesser : iterable(disambModule.overloads().sortedElements())) {
for (Production lesser : disambModule.overloads().sortedElements()) {
if (disambModule.overloads().lessThan(lesser, greater)) {
bison.append(" if (");
appendOverloadCondition(bison, module, greater, lesser, nts);
Expand Down

0 comments on commit 2dbd10c

Please sign in to comment.