From fca3e0378c5b274806b48be0479bcd6610d248c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=93lafur=20P=C3=A1ll=20Geirsson?= Date: Mon, 12 Dec 2016 18:33:15 +0100 Subject: [PATCH] cats.data.Xor/XorT to Either/EitherT rewrite (+7 squashed commits) Squashed commits: [b5aaec5] Pushing a few experiments for feedback. [445b2cd] Experiment using desugared to rewrite xor [13935bf] WIP [40caaa3] Pick up configuration in sbt plugin. [ab34355] Use build-info for friendlier intellij setup [a5b666a] Setup hocon configuration. [c94fef1] Cross-build to 2.12 (#28) * Upgrade scalameta and cross-build to 2.12 * Gracefully handle 2.11 and 2.12 in sbt plugin. * Don't rely on sbt.ivy.home. --- .../main/scala/scalafix/rewrite/Rewrite.scala | 2 +- .../scala/scalafix/rewrite/SemanticApi.scala | 13 ++++- .../scala/scalafix/rewrite/Xor2Either.scala | 53 +++++++++++++++++++ .../main/scala/scalafix/util/AddImport.scala | 37 +++++++++++++ .../main/scala/scalafix/util/AnyDiff.scala | 45 ++++++++++++++++ .../scala/scalafix/util/ChangeMethod.scala | 53 +++++++++++++++++++ .../main/scala/scalafix/util/ChangeType.scala | 39 ++++++++++++++ core/src/main/scala/scalafix/util/Patch.scala | 20 +++++++ .../scalafix/util/StructurallyEqual.scala | 42 +++++++++++++++ .../scala/scalafix/util/syntax/package.scala | 25 +++++++++ core/src/test/resources/Xor/basic.source | 36 +++++++++++++ .../scala/scalafix/nsc/NscSemanticApi.scala | 43 ++++++++++++++- .../src/test/scala/cats/data/Xor.scala | 17 ++++++ .../test/scala/cats/implicits/package.scala | 9 ++++ 14 files changed, 430 insertions(+), 4 deletions(-) create mode 100644 core/src/main/scala/scalafix/rewrite/Xor2Either.scala create mode 100644 core/src/main/scala/scalafix/util/AddImport.scala create mode 100644 core/src/main/scala/scalafix/util/AnyDiff.scala create mode 100644 core/src/main/scala/scalafix/util/ChangeMethod.scala create mode 100644 core/src/main/scala/scalafix/util/ChangeType.scala create mode 100644 core/src/main/scala/scalafix/util/StructurallyEqual.scala create mode 100644 core/src/main/scala/scalafix/util/syntax/package.scala create mode 100644 core/src/test/resources/Xor/basic.source create mode 100644 scalafix-nsc/src/test/scala/cats/data/Xor.scala create mode 100644 scalafix-nsc/src/test/scala/cats/implicits/package.scala diff --git a/core/src/main/scala/scalafix/rewrite/Rewrite.scala b/core/src/main/scala/scalafix/rewrite/Rewrite.scala index e7ad73e21..5cf320285 100644 --- a/core/src/main/scala/scalafix/rewrite/Rewrite.scala +++ b/core/src/main/scala/scalafix/rewrite/Rewrite.scala @@ -17,7 +17,7 @@ object Rewrite { } val syntaxRewrites: Seq[Rewrite] = Seq(ProcedureSyntax, VolatileLazyVal) - val semanticRewrites: Seq[Rewrite] = Seq(ExplicitImplicit) + val semanticRewrites: Seq[Rewrite] = Seq(ExplicitImplicit, Xor2Either) val allRewrites: Seq[Rewrite] = syntaxRewrites ++ semanticRewrites val defaultRewrites: Seq[Rewrite] = allRewrites.filterNot(_ == VolatileLazyVal) diff --git a/core/src/main/scala/scalafix/rewrite/SemanticApi.scala b/core/src/main/scala/scalafix/rewrite/SemanticApi.scala index dfd0b5262..1f8181ca1 100644 --- a/core/src/main/scala/scalafix/rewrite/SemanticApi.scala +++ b/core/src/main/scala/scalafix/rewrite/SemanticApi.scala @@ -1,7 +1,7 @@ package scalafix.rewrite -import scala.meta.Defn -import scala.meta.Type +import scala.meta.{Defn, Term, Tree, Type} +import scala.meta.parsers.Parse /** A custom semantic api for scalafix rewrites. * @@ -18,4 +18,13 @@ trait SemanticApi { /** Returns the type annotation for given val/def. */ def typeSignature(defn: Defn): Option[Type] + + def desugared[A <: Tree](tree: A)(implicit parse: Parse[A]): Option[A] + + class Desugared[T <: Tree: Parse] { + def unapply(original: T): Option[T] = desugared(original) + } + + object DType extends Desugared[Type] + object DTerm extends Desugared[Term] } diff --git a/core/src/main/scala/scalafix/rewrite/Xor2Either.scala b/core/src/main/scala/scalafix/rewrite/Xor2Either.scala new file mode 100644 index 000000000..87671b041 --- /dev/null +++ b/core/src/main/scala/scalafix/rewrite/Xor2Either.scala @@ -0,0 +1,53 @@ +package scalafix.rewrite + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scalafix.util._ +import scala.meta._ + +case object Xor2Either extends Rewrite { + override def rewrite(ast: m.Tree, ctx: RewriteCtx): Seq[Patch] = { + implicit val semanticApi: SemanticApi = getSemanticApi(ctx) + + //Create a sequence of type replacements + val replacementTypes = List( + ReplaceType(t"cats.data.XorT", t"cats.data.EitherT", "EitherT"), + ReplaceType(t"cats.data.Xor", t"scala.util.Either", "Either"), + ReplaceType(t"cats.data.Xor.Left", t"scala.util.Left", "Left"), + ReplaceType(t"cats.data.Xor.Right", t"scala.util.Either.Right", "Right") + ) + + //Add in some method replacements + val replacementTerms = List( + ReplaceTerm(q"cats.data.Xor.Right.apply", + q"scala.util.Right.apply", + q"scala.util"), + ReplaceTerm(q"cats.data.Xor.Left.apply", + q"scala.util.Left.apply", + q"scala.util") + ) + + //Then add needed imports. + //todo - derive this from patches created, types + //and terms replaced + //Only add if they are not already imported + val additionalImports = List( + "cats.data.EitherT", + "cats.implicits._", + "scala.util.Either" + ) + + val typeReplacements = + new ChangeType(ast).gatherPatches(replacementTypes) + + val termReplacements = + new ChangeMethod(ast).gatherPatches(replacementTerms) + + //Make this additional imports, beyond what can be derived from the types + val addedImports = + if (typeReplacements.isEmpty && termReplacements.isEmpty) Seq[Patch]() + else new AddImport(ast).gatherPatches(additionalImports) + + addedImports ++ typeReplacements ++ termReplacements + } +} diff --git a/core/src/main/scala/scalafix/util/AddImport.scala b/core/src/main/scala/scalafix/util/AddImport.scala new file mode 100644 index 000000000..470a4b46c --- /dev/null +++ b/core/src/main/scala/scalafix/util/AddImport.scala @@ -0,0 +1,37 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scala.meta._ +import scalafix.rewrite._ + +class AddImport(ast: m.Tree)(implicit sApi: SemanticApi) { + val allImports = ast.collect { + case t @ q"import ..$importersnel" => t -> importersnel + } + + val firstImport = allImports.headOption + val firstImportFirstToken = firstImport.flatMap { + case (importStatement, _) => importStatement.tokens.headOption + } + val tokenBeforeFirstImport = firstImportFirstToken.flatMap { stopAt => + ast.tokens.takeWhile(_ != stopAt).lastOption + } + + //This is currently a very dumb implementation. + //It does no checking for existing imports and makes + //no attempt to consolidate imports + def addedImports(importString: String): Seq[Patch] = + tokenBeforeFirstImport + .map( + beginImportsLocation => + Patch + .insertAfter(beginImportsLocation, importString) + ) + .toList + + def gatherPatches(imports: Seq[String]): Seq[Patch] = { + val importStrings = imports.map("import " + _).mkString("\n", "\n", "\n") + addedImports(importStrings) + } +} diff --git a/core/src/main/scala/scalafix/util/AnyDiff.scala b/core/src/main/scala/scalafix/util/AnyDiff.scala new file mode 100644 index 000000000..8ce724c1b --- /dev/null +++ b/core/src/main/scala/scalafix/util/AnyDiff.scala @@ -0,0 +1,45 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.meta.Tree + +/** Helper class to create textual diff between two objects */ +case class AnyDiff(a: Any, b: Any) extends Exception { + override def toString: String = s"""$a != $b $mismatchClass""" + def detailed: String = compare(a, b) + + /** Best effort attempt to find a line number for scala.meta.Tree */ + def lineNumber: Int = + 1 + (a match { + case e: Tree => e.pos.start.line + case Some(t: Tree) => t.pos.start.line + case lst: Seq[_] => + lst match { + case (head: Tree) :: tail => head.pos.start.line + case _ => -2 + } + case _ => -2 + }) + def mismatchClass: String = + if (clsName(a) != clsName(b)) s"(${clsName(a)} != ${clsName(b)})" + else s"same class ${clsName(a)}" + + private def clsName(a: Any) = a.getClass.getName + + private def compare(a: Any, b: Any): String = + (a, b) match { + case (t1: Tree, t2: Tree) => + s"""$toString + |Syntax diff: + |${t1.syntax} + |${t2.syntax} + | + |Structure diff: + |${t1.structure} + |${t2.structure} + """.stripMargin + case (t1: Seq[_], t2: Seq[_]) => + t1.zip(t2).map { case (a, b) => compare(a, b) }.mkString + case _ => toString + } +} diff --git a/core/src/main/scala/scalafix/util/ChangeMethod.scala b/core/src/main/scala/scalafix/util/ChangeMethod.scala new file mode 100644 index 000000000..ad5db73f3 --- /dev/null +++ b/core/src/main/scala/scalafix/util/ChangeMethod.scala @@ -0,0 +1,53 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scala.meta._ +import scalafix.rewrite._ +import syntax._ + +case class ReplaceTerm( + original: m.Term, + replacement: m.Term, + importSegment: m.Term +) { + require(replacement.syntax.contains(importSegment.syntax), + "Specify the portion of the new Term to be imported") +} + +class ChangeMethod(ast: m.Tree)(implicit sApi: SemanticApi) { + + import sApi._ + + def partialTermMatch(rt: ReplaceTerm) + : PartialFunction[m.Tree, (scala.meta.Term, ReplaceTerm)] = { + // only desugar selections + case t @ q"$expr.$name" & DTerm(desugared) + if desugared.termNames.map(_.syntax) == rt.original.termNames.map( + _.syntax) => + //I am not sure how to get rid of this asInstanceOf + //But it should be safe because I have matched on DTerm above + t.asInstanceOf[m.Term] -> rt + } + + def partialTermMatches(replacementTerms: Seq[ReplaceTerm]) + : PartialFunction[m.Tree, (m.Term, ReplaceTerm)] = + replacementTerms.map(partialTermMatch).reduce(_ orElse _) + + def terms(ptm: PartialFunction[m.Tree, (m.Term, ReplaceTerm)]) = + ast.collect(ptm) + + def termReplacements(trms: Seq[(m.Term, ReplaceTerm)]): Seq[Patch] = + trms.map { + case (t, ReplaceTerm(oldTerm, newTerm, imported)) => + val replacement = newTerm.termNames + .map(_.syntax) + .diff(imported.termNames.map(_.syntax)) //Strip off imported portion + .filterNot(_ == "apply") //special handling for apply, suppress as it will be inferred + .mkString(".") + Patch.replace(t, replacement) + } + + def gatherPatches(tr: Seq[ReplaceTerm]): Seq[Patch] = + termReplacements(terms(partialTermMatches(tr))) +} diff --git a/core/src/main/scala/scalafix/util/ChangeType.scala b/core/src/main/scala/scalafix/util/ChangeType.scala new file mode 100644 index 000000000..3b9921368 --- /dev/null +++ b/core/src/main/scala/scalafix/util/ChangeType.scala @@ -0,0 +1,39 @@ +package scalafix.util + +import scala.collection.immutable.Seq +import scala.{meta => m} +import scalafix.rewrite._ + +//Provide a little structure to the replacements we will be performing +case class ReplaceType(original: m.Type, + replacement: m.Type, + newString: String) { + def toPatch(t: m.Type): Patch = Patch.replace(t, newString) +} + +class ChangeType(ast: m.Tree)(implicit sApi: SemanticApi) { + import sApi._ + + def partialTypeMatch( + rt: ReplaceType): PartialFunction[m.Tree, (m.Type, ReplaceType)] = { + case t @ DType(desugared) + if StructurallyEqual(desugared, rt.original).isRight => + t -> rt + } + + def partialTypeMatches(replacementTypes: Seq[ReplaceType]) + : PartialFunction[m.Tree, (m.Type, ReplaceType)] = + replacementTypes.map(partialTypeMatch).reduce(_ orElse _) + + def tpes(ptm: PartialFunction[m.Tree, (m.Type, ReplaceType)]) + : Seq[(m.Type, ReplaceType)] = ast.collect { ptm } + + //This is unsafe, come up with something better + def typeReplacements(tpes: Seq[(m.Type, ReplaceType)]): Seq[Patch] = + tpes.map { + case (tree, rt) => rt.toPatch(tree) + } + + def gatherPatches(tr: Seq[ReplaceType]): Seq[Patch] = + typeReplacements(tpes(partialTypeMatches(tr))) +} diff --git a/core/src/main/scala/scalafix/util/Patch.scala b/core/src/main/scala/scalafix/util/Patch.scala index 6be15ae2a..392c30b3d 100644 --- a/core/src/main/scala/scalafix/util/Patch.scala +++ b/core/src/main/scala/scalafix/util/Patch.scala @@ -35,4 +35,24 @@ object Patch { .map(_.syntax) .mkString("") } + + def replace(token: Token, replacement: String): Patch = + Patch(token, token, replacement) + + def replace(tree: Tree, replacement: String): Patch = + Patch(tree.tokens.head, tree.tokens.last, replacement) + + def insertBefore(token: Token, toPrepend: String) = + replace(token, s"$toPrepend${token.syntax}") + + def insertBefore(tree: Tree, toPrepend: String): Patch = + replace(tree, s"$toPrepend${tree.syntax}") + + def insertAfter(token: Token, toAppend: String) = + replace(token, s"$toAppend${token.syntax}") + + def insertAfter(tree: Tree, toAppend: String): Patch = + replace(tree, s"${tree.syntax}$toAppend") + + def delete(tree: Tree): Patch = replace(tree, "") } diff --git a/core/src/main/scala/scalafix/util/StructurallyEqual.scala b/core/src/main/scala/scalafix/util/StructurallyEqual.scala new file mode 100644 index 000000000..4fb1147d1 --- /dev/null +++ b/core/src/main/scala/scalafix/util/StructurallyEqual.scala @@ -0,0 +1,42 @@ +package scalafix.util + +import scala.collection.immutable.Seq + +object StructurallyEqual { + import scala.meta.Tree + + /** Test if two trees are structurally equal. + * @return Left(errorMessage with minimal diff) if trees are not structurally + * different, otherwise Right(Unit). To convert into exception with + * meaningful error message, + * val Right(_) = StructurallyEqual(a, b) + **/ + def apply(a: Tree, b: Tree): Either[AnyDiff, Unit] = { + def loop(x: Any, y: Any): Boolean = { + val ok: Boolean = (x, y) match { + case (x, y) if x == null || y == null => x == null && y == null + case (x: Some[_], y: Some[_]) => loop(x.get, y.get) + case (x: None.type, y: None.type) => true + case (xs: Seq[_], ys: Seq[_]) => + xs.length == ys.length && + xs.zip(ys).forall { + case (x, y) => loop(x, y) + } + case (x: Tree, y: Tree) => + def sameStructure = + x.productPrefix == y.productPrefix && + loop(x.productIterator.toList, y.productIterator.toList) + sameStructure + case _ => x == y + } + if (!ok) throw AnyDiff(x, y) + else true + } + try { + loop(a, b) + Right(Unit) + } catch { + case t: AnyDiff => Left(t) + } + } +} diff --git a/core/src/main/scala/scalafix/util/syntax/package.scala b/core/src/main/scala/scalafix/util/syntax/package.scala new file mode 100644 index 000000000..0f4bacbd0 --- /dev/null +++ b/core/src/main/scala/scalafix/util/syntax/package.scala @@ -0,0 +1,25 @@ +package scalafix.util + +import scala.{meta => m} +import scala.meta._ + +package object syntax { + + //Allow two patterns to be combined. Contributed by @nafg + object & { def unapply[A](a: A) = Some((a, a)) } + + implicit class MetaOps(from: m.Tree) { + def termNames: List[Term.Name] = { + from collect { + case t: Term.Name => t + } + } + + def typeNames: List[Type.Name] = { + from collect { + case t: Type.Name => t + } + } + } + +} diff --git a/core/src/test/resources/Xor/basic.source b/core/src/test/resources/Xor/basic.source new file mode 100644 index 000000000..e83f2e869 --- /dev/null +++ b/core/src/test/resources/Xor/basic.source @@ -0,0 +1,36 @@ +rewrites = [Xor2Either] +<<< xor 1 +import scala.concurrent.Future +import cats.data.{ Xor, XorT } +trait A { +type MyDisjunction = Xor[Int, String] + val r: MyDisjunction = Xor.Right.apply("") + val s: Xor[Int, String] = cats.data.Xor.Left(1 /* comment */) + val t: Xor[Int, String] = r.map(_ + "!") + val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]] + val u: XorT[Future, Int, String] = ??? +} +>>> +import cats.data.EitherT +import cats.implicits._ +import scala.util.Either +import scala.concurrent.Future +import cats.data.{ Xor, XorT } +trait A { +type MyDisjunction = Either[Int, String] + val r: MyDisjunction = Right("") + val s: Either[Int, String] = Left(1 /* comment */) + val t: Either[Int, String] = r.map(_ + "!") + val nest: Seq[Either[Int, Either[String, Int]]] + val u: EitherT[Future, Int, String] = ??? +} +<<< xor do not modify when not present +import scala.concurrent.Future +trait A { + val num = 1 +} +>>> +import scala.concurrent.Future +trait A { + val num = 1 +} diff --git a/scalafix-nsc/src/main/scala/scalafix/nsc/NscSemanticApi.scala b/scalafix-nsc/src/main/scala/scalafix/nsc/NscSemanticApi.scala index 0a849dc51..c67ffeb24 100644 --- a/scalafix-nsc/src/main/scala/scalafix/nsc/NscSemanticApi.scala +++ b/scalafix-nsc/src/main/scala/scalafix/nsc/NscSemanticApi.scala @@ -2,7 +2,9 @@ package scalafix.nsc import scala.collection.mutable import scala.meta.Dialect +import scala.meta.Tree import scala.meta.Type +import scala.meta.parsers.Parse import scala.reflect.internal.util.SourceFile import scala.{meta => m} import scalafix.Fixed @@ -14,6 +16,13 @@ import scalafix.util.logger case class SemanticContext(enclosingPackage: String, inScope: List[String]) trait NscSemanticApi extends ReflectToolkit { + implicit class XtensionPosition(gpos: scala.reflect.internal.util.Position) { + def matches(mpos: m.Position): Boolean = { + gpos.isDefined && + gpos.start == mpos.start.offset && + gpos.end == mpos.end.offset + } + } /** Returns a map from byte offset to type name at that offset. */ private def offsetToType(gtree: g.Tree, @@ -53,7 +62,7 @@ trait NscSemanticApi extends ReflectToolkit { val parsed = dialect(gtree.toString()).parse[m.Type] parsed match { - case m.Parsed.Success(ast) => + case m.parsers.Parsed.Success(ast) => builder(gtree.pos.point) = cleanUp(ast) case _ => } @@ -97,6 +106,23 @@ trait NscSemanticApi extends ReflectToolkit { builder } + private def collect[T](gtree: g.Tree)( + pf: PartialFunction[g.Tree, T]): Seq[T] = { + val builder = Seq.newBuilder[T] + val f = pf.lift + def iter(gtree: g.Tree): Unit = { + f(gtree).foreach(builder += _) + gtree match { + case t @ g.TypeTree() if t.original != null && t.original.nonEmpty => + iter(t.original) + case _ => + gtree.children.foreach(iter) + } + } + iter(gtree) + builder.result() + } + private def getSemanticApi(unit: g.CompilationUnit, config: ScalafixConfig): SemanticApi = { val offsets = offsetToType(unit.body, config.dialect) @@ -111,6 +137,21 @@ trait NscSemanticApi extends ReflectToolkit { None } } + + override def desugared[T <: Tree](tree: T)( + implicit parse: Parse[T]): Option[T] = { + val result = collect[Option[T]](unit.body) { +// case t if { logger.elem(t.toString(), g.showRaw(t)); false } => None + case t if t.pos.matches(tree.pos) => + import scala.meta._ + parse(m.Input.String(t.toString()), config.dialect) match { + case m.parsers.Parsed.Success(x) => Some(x) + case _ => None + } + }.flatten +// logger.elem(result) + result.headOption + } } } diff --git a/scalafix-nsc/src/test/scala/cats/data/Xor.scala b/scalafix-nsc/src/test/scala/cats/data/Xor.scala new file mode 100644 index 000000000..8e1e16cb3 --- /dev/null +++ b/scalafix-nsc/src/test/scala/cats/data/Xor.scala @@ -0,0 +1,17 @@ +package cats.data +import scala.language.higherKinds + +sealed abstract class Xor[+A, +B] extends Product with Serializable { + def map[C](f: B => C) = ??? +} + +object Xor { + def left[A, B](a: A): A Xor B = Xor.Left(a) + def right[A, B](b: B): A Xor B = Xor.Right(b) + final case class Left[+A](a: A) extends (A Xor Nothing) + final case class Right[+B](b: B) extends (Nothing Xor B) +} + +sealed abstract class XorT[F[_], A, B](value: F[A Xor B]) + +sealed abstract class EitherT[F[_], A, B](value: F[Either[A, B]]) diff --git a/scalafix-nsc/src/test/scala/cats/implicits/package.scala b/scalafix-nsc/src/test/scala/cats/implicits/package.scala new file mode 100644 index 000000000..2cb979e2f --- /dev/null +++ b/scalafix-nsc/src/test/scala/cats/implicits/package.scala @@ -0,0 +1,9 @@ +package cats + +import scala.language.implicitConversions + +package object implicits { + implicit class EitherOps[A, B](from: Either[A, B]) { + def map[C](f: B => C): Either[A, C] = ??? + } +}