Skip to content

Commit

Permalink
Experiment using desugared to rewrite xor
Browse files Browse the repository at this point in the history
  • Loading branch information
olafurpg committed Dec 13, 2016
1 parent 13935bf commit 445b2cd
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 25 deletions.
28 changes: 16 additions & 12 deletions core/src/main/scala/scalafix/rewrite/Xor2Either.scala
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
package scalafix.rewrite

import scala.meta.parsers.Parse
import scala.{meta => m}
import scalafix.util.Patch
import scalafix.util.Whitespace
import scalafix.util.logger

case object Xor2Either extends Rewrite {
override def rewrite(ast: m.Tree, ctx: RewriteCtx): Seq[Patch] = {
import scala.meta._
val semantic = getSemanticApi(ctx)
class Desugared[T <: Tree: Parse] {
def unapply(original: T): Option[T] = semantic.desugared(original)
}
object DType extends Desugared[Type]
object DTerm extends Desugared[Term]
// NOTE. This approach is super inefficient, since we run semantic.desugar on
// every case for every node in the tree. Ideally, we first match on the
// syntax structure we want and then run semantic.desugar.
ast.collect {
case t: m.Type.Ref
if semantic
.desugared(t: m.Type)(m.parsers.Parse.parseType)
.exists(_.syntax.contains("Xor")) =>
logger.elem(t.syntax, semantic.desugared(t: Type))
val tok = t.tokens.head
Seq(
Patch(tok, tok, tok.syntax + "BANANA")
)
}.flatten
case t @ DType(t"cats.data.Xor") =>
Patch(t.tokens.head, t.tokens.last, s"Either")
case t @ DTerm(q"cats.data.Xor.Right.apply[..$_]") =>
Patch(t.tokens.head, t.tokens.last, s"Right")
case t @ DTerm(q"cats.data.Xor.Left.apply[..$_]") =>
Patch(t.tokens.head, t.tokens.last, s"Left")
}
}
}
11 changes: 8 additions & 3 deletions core/src/test/resources/Xor/basic.source
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
rewrites = [Xor2Either]
<<< ONLY xor 1
<<< xor 1
import cats.data.Xor
trait A {
val r: Xor[Int, String]
val r: Xor[Int, String] = Xor.Right("")
val s: Xor[Int, String] = Xor.Left(1 /* comment */)
val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]]
}
>>>
import cats.data.Xor
trait A {
val r: Either[Int, String]
val r: Either[Int, String] = Right("")
val s: Either[Int, String] = Left(1 /* comment */)
val nest: Seq[Either[Int, Either[String, Int]]]
}
27 changes: 18 additions & 9 deletions scalafix-nsc/src/main/scala/scalafix/nsc/NscSemanticApi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ import scalafix.util.logger
case class SemanticContext(enclosingPackage: String, inScope: List[String])

trait NscSemanticApi extends ReflectToolkit {
implicit class XtensionPosition(gpos: scala.reflect.internal.util.Position) {
def matches(mpos: m.Position): Boolean = {
gpos.isDefined &&
gpos.start == mpos.start.offset &&
gpos.end == mpos.end.offset
}
}

/** Returns a map from byte offset to type name at that offset. */
private def offsetToType(gtree: g.Tree,
Expand Down Expand Up @@ -55,7 +62,7 @@ trait NscSemanticApi extends ReflectToolkit {

val parsed = dialect(gtree.toString()).parse[m.Type]
parsed match {
case m.Parsed.Success(ast) =>
case m.parsers.Parsed.Success(ast) =>
builder(gtree.pos.point) = cleanUp(ast)
case _ =>
}
Expand Down Expand Up @@ -105,7 +112,12 @@ trait NscSemanticApi extends ReflectToolkit {
val f = pf.lift
def iter(gtree: g.Tree): Unit = {
f(gtree).foreach(builder += _)
gtree.children.foreach(iter)
gtree match {
case t @ g.TypeTree() if t.original != null && t.original.nonEmpty =>
iter(t.original)
case _ =>
gtree.children.foreach(iter)
}
}
iter(gtree)
builder.result()
Expand All @@ -128,19 +140,16 @@ trait NscSemanticApi extends ReflectToolkit {

override def desugared[T <: Tree](tree: T)(
implicit parse: Parse[T]): Option[T] = {
logger.elem(tree, unit.body)
val result = collect[Option[T]](unit.body) {
case t
if t.pos.isDefined &&
t.pos.start == tree.pos.start.offset =>
// case t if { logger.elem(t.toString(), g.showRaw(t)); false } => None
case t if t.pos.matches(tree.pos) =>
import scala.meta._
parse(m.Input.String(t.toString()), config.dialect) match {
case Parsed.Success(x) => Some(x)
case m.parsers.Parsed.Success(x) => Some(x)
case _ => None
}
case t if { logger.elem(t.toString(), g.showRaw(t)); false } => None
}.flatten
logger.elem(result)
// logger.elem(result)
result.headOption
}
}
Expand Down
2 changes: 2 additions & 0 deletions scalafix-nsc/src/test/scala/cats/data/Xor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cats.data
sealed abstract class Xor[+A, +B] extends Product with Serializable

object Xor {
def left[A, B](a: A): A Xor B = Xor.Left(a)
def right[A, B](b: B): A Xor B = Xor.Right(b)
final case class Left[+A](a: A) extends (A Xor Nothing)
final case class Right[+B](b: B) extends (Nothing Xor B)
}
12 changes: 11 additions & 1 deletion scalafix-nsc/src/test/scala/scalafix/SemanticTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ class SemanticTests extends FunSuite {
sys.error(s"ReflectToMeta initialization failed: $msg")
val classpath = System.getProperty("sbt.paths.scalafixNsc.test.classes")
val pluginpath = System.getProperty("sbt.paths.plugin.jar")
val options = "-cp " + classpath + " -Xplugin:" + pluginpath + ":" + classpath + " -Xplugin-require:scalafix"
val scalacOptions = Seq[String](
"-Yrangepos" // necessary to match reflect positions with meta positions.
)
val options =
"-cp " + classpath +
" -Xplugin:" + pluginpath + ":" + classpath +
" -Xplugin-require:scalafix" +
scalacOptions.mkString(" ", " ", " ")

val args = CommandLineParser.tokenize(options)
val emptySettings = new Settings(
error => fail(s"couldn't apply settings because $error"))
Expand Down Expand Up @@ -142,10 +150,12 @@ class SemanticTests extends FunSuite {

def check(original: String, expectedStr: String, diffTest: DiffTest): Unit = {
val fixed = fix(wrap(original, diffTest), diffTest.config)
// logger.elem(fixed)
val obtained = parse(fixed)
val expected = parse(expectedStr)
try {
checkMismatchesModuloDesugarings(obtained, expected)

if (!diffTest.noWrap) typeChecks(wrap(fixed, diffTest))
} catch {
case MismatchException(details) =>
Expand Down

0 comments on commit 445b2cd

Please sign in to comment.