Skip to content

Commit

Permalink
Better handling of union of atoms in maps:get
Browse files Browse the repository at this point in the history
Summary: Collect possible atom literals for the key in custom elaboration of `maps:get` and union the associated values.

Reviewed By: ilya-klyuchnikov

Differential Revision: D54419232

fbshipit-source-id: c7d3934000726885d45f3d5e0d32580a615fbc56
  • Loading branch information
VLanvin authored and facebook-github-bot committed Mar 1, 2024
1 parent ef8dbf3 commit cb8e23b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,12 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
if (!subtype.subType(mapTy, anyMapTy))
throw ExpectedSubtype(map.pos, map, expected = anyMapTy, got = mapTy)
val mapType = narrow.asMapType(mapTy)
keyTy match {
case AtomLitType(key) =>
val valTy = narrow.getValType(key, mapType)
(valTy, env1)
case _ =>
val atomKeys = narrow.asAtomLits(keyTy)
atomKeys match {
case Some(atoms) =>
val valTys = atoms.map(narrow.getValType(_, mapType))
(subtype.join(valTys), env1)
case None =>
val valTy = narrow.getValType(mapType)
(valTy, env1)
}
Expand Down Expand Up @@ -334,11 +335,12 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
if (!subtype.subType(mapTy, anyMapTy))
throw ExpectedSubtype(map.pos, map, expected = anyMapTy, got = mapTy)
val mapType = narrow.asMapType(mapTy)
keyTy match {
case AtomLitType(key) =>
val valTy = narrow.getValType(key, mapType)
(subtype.join(valTy, defaultValTy), env1)
case _ =>
val atomKeys = narrow.asAtomLits(keyTy)
atomKeys match {
case Some(atoms) =>
val valTys = atoms.map(narrow.getValType(_, mapType))
(subtype.join(valTys + defaultValTy), env1)
case None =>
val valTy = narrow.getValType(mapType)
(subtype.join(valTy, defaultValTy), env1)
}
Expand Down
16 changes: 16 additions & 0 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -538,4 +538,20 @@ class Narrow(pipelineContext: PipelineContext) {
throw new IllegalStateException()
}
}

// Recursion is sound since we don't unfold under constructors
def asAtomLits(t: Type): Option[Set[String]] =
t match {
case AtomLitType(s) => Some(Set(s))
case BoundedDynamicType(bound) =>
asAtomLits(bound)
case UnionType(ts) =>
ts.foldLeft[Option[Set[String]]](Some(Set())) { (acc, ty) =>
acc.flatMap(atoms => asAtomLits(ty).map(atoms2 => atoms ++ atoms2))
}
case RemoteType(rid, args) =>
val body = util.getTypeDeclBody(rid, args)
asAtomLits(body)
case _ => None
}
}

0 comments on commit cb8e23b

Please sign in to comment.