Skip to content

Commit

Permalink
cats.data.Xor/XorT to Either/EitherT rewrite (+7 squashed commits)
Browse files Browse the repository at this point in the history
Squashed commits:
[b5aaec5] Pushing a few experiments for feedback.
[445b2cd] Experiment using desugared to rewrite xor
[13935bf] WIP
[40caaa3] Pick up configuration in sbt plugin.
[ab34355] Use build-info for friendlier intellij setup
[a5b666a] Setup hocon configuration.
[c94fef1] Cross-build to 2.12 (scalacenter#28)

* Upgrade scalameta and cross-build to 2.12

* Gracefully handle 2.11 and 2.12 in sbt plugin.

* Don't rely on sbt.ivy.home.
  • Loading branch information
olafurpg authored and ShaneDelmore committed Dec 26, 2016
1 parent e2f97d6 commit fca3e03
Show file tree
Hide file tree
Showing 14 changed files with 430 additions and 4 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/scalafix/rewrite/Rewrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ object Rewrite {
}

val syntaxRewrites: Seq[Rewrite] = Seq(ProcedureSyntax, VolatileLazyVal)
val semanticRewrites: Seq[Rewrite] = Seq(ExplicitImplicit)
val semanticRewrites: Seq[Rewrite] = Seq(ExplicitImplicit, Xor2Either)
val allRewrites: Seq[Rewrite] = syntaxRewrites ++ semanticRewrites
val defaultRewrites: Seq[Rewrite] =
allRewrites.filterNot(_ == VolatileLazyVal)
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/scalafix/rewrite/SemanticApi.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package scalafix.rewrite

import scala.meta.Defn
import scala.meta.Type
import scala.meta.{Defn, Term, Tree, Type}
import scala.meta.parsers.Parse

/** A custom semantic api for scalafix rewrites.
*
Expand All @@ -18,4 +18,13 @@ trait SemanticApi {

/** Returns the type annotation for given val/def. */
def typeSignature(defn: Defn): Option[Type]

def desugared[A <: Tree](tree: A)(implicit parse: Parse[A]): Option[A]

class Desugared[T <: Tree: Parse] {
def unapply(original: T): Option[T] = desugared(original)
}

object DType extends Desugared[Type]
object DTerm extends Desugared[Term]
}
53 changes: 53 additions & 0 deletions core/src/main/scala/scalafix/rewrite/Xor2Either.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package scalafix.rewrite

import scala.collection.immutable.Seq
import scala.{meta => m}
import scalafix.util._
import scala.meta._

case object Xor2Either extends Rewrite {
override def rewrite(ast: m.Tree, ctx: RewriteCtx): Seq[Patch] = {
implicit val semanticApi: SemanticApi = getSemanticApi(ctx)

//Create a sequence of type replacements
val replacementTypes = List(
ReplaceType(t"cats.data.XorT", t"cats.data.EitherT", "EitherT"),
ReplaceType(t"cats.data.Xor", t"scala.util.Either", "Either"),
ReplaceType(t"cats.data.Xor.Left", t"scala.util.Left", "Left"),
ReplaceType(t"cats.data.Xor.Right", t"scala.util.Either.Right", "Right")
)

//Add in some method replacements
val replacementTerms = List(
ReplaceTerm(q"cats.data.Xor.Right.apply",
q"scala.util.Right.apply",
q"scala.util"),
ReplaceTerm(q"cats.data.Xor.Left.apply",
q"scala.util.Left.apply",
q"scala.util")
)

//Then add needed imports.
//todo - derive this from patches created, types
//and terms replaced
//Only add if they are not already imported
val additionalImports = List(
"cats.data.EitherT",
"cats.implicits._",
"scala.util.Either"
)

val typeReplacements =
new ChangeType(ast).gatherPatches(replacementTypes)

val termReplacements =
new ChangeMethod(ast).gatherPatches(replacementTerms)

//Make this additional imports, beyond what can be derived from the types
val addedImports =
if (typeReplacements.isEmpty && termReplacements.isEmpty) Seq[Patch]()
else new AddImport(ast).gatherPatches(additionalImports)

addedImports ++ typeReplacements ++ termReplacements
}
}
37 changes: 37 additions & 0 deletions core/src/main/scala/scalafix/util/AddImport.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scala.meta._
import scalafix.rewrite._

class AddImport(ast: m.Tree)(implicit sApi: SemanticApi) {
val allImports = ast.collect {
case t @ q"import ..$importersnel" => t -> importersnel
}

val firstImport = allImports.headOption
val firstImportFirstToken = firstImport.flatMap {
case (importStatement, _) => importStatement.tokens.headOption
}
val tokenBeforeFirstImport = firstImportFirstToken.flatMap { stopAt =>
ast.tokens.takeWhile(_ != stopAt).lastOption
}

//This is currently a very dumb implementation.
//It does no checking for existing imports and makes
//no attempt to consolidate imports
def addedImports(importString: String): Seq[Patch] =
tokenBeforeFirstImport
.map(
beginImportsLocation =>
Patch
.insertAfter(beginImportsLocation, importString)
)
.toList

def gatherPatches(imports: Seq[String]): Seq[Patch] = {
val importStrings = imports.map("import " + _).mkString("\n", "\n", "\n")
addedImports(importStrings)
}
}
45 changes: 45 additions & 0 deletions core/src/main/scala/scalafix/util/AnyDiff.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.meta.Tree

/** Helper class to create textual diff between two objects */
case class AnyDiff(a: Any, b: Any) extends Exception {
override def toString: String = s"""$a != $b $mismatchClass"""
def detailed: String = compare(a, b)

/** Best effort attempt to find a line number for scala.meta.Tree */
def lineNumber: Int =
1 + (a match {
case e: Tree => e.pos.start.line
case Some(t: Tree) => t.pos.start.line
case lst: Seq[_] =>
lst match {
case (head: Tree) :: tail => head.pos.start.line
case _ => -2
}
case _ => -2
})
def mismatchClass: String =
if (clsName(a) != clsName(b)) s"(${clsName(a)} != ${clsName(b)})"
else s"same class ${clsName(a)}"

private def clsName(a: Any) = a.getClass.getName

private def compare(a: Any, b: Any): String =
(a, b) match {
case (t1: Tree, t2: Tree) =>
s"""$toString
|Syntax diff:
|${t1.syntax}
|${t2.syntax}
|
|Structure diff:
|${t1.structure}
|${t2.structure}
""".stripMargin
case (t1: Seq[_], t2: Seq[_]) =>
t1.zip(t2).map { case (a, b) => compare(a, b) }.mkString
case _ => toString
}
}
53 changes: 53 additions & 0 deletions core/src/main/scala/scalafix/util/ChangeMethod.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scala.meta._
import scalafix.rewrite._
import syntax._

case class ReplaceTerm(
original: m.Term,
replacement: m.Term,
importSegment: m.Term
) {
require(replacement.syntax.contains(importSegment.syntax),
"Specify the portion of the new Term to be imported")
}

class ChangeMethod(ast: m.Tree)(implicit sApi: SemanticApi) {

import sApi._

def partialTermMatch(rt: ReplaceTerm)
: PartialFunction[m.Tree, (scala.meta.Term, ReplaceTerm)] = {
// only desugar selections
case t @ q"$expr.$name" & DTerm(desugared)
if desugared.termNames.map(_.syntax) == rt.original.termNames.map(
_.syntax) =>
//I am not sure how to get rid of this asInstanceOf
//But it should be safe because I have matched on DTerm above
t.asInstanceOf[m.Term] -> rt
}

def partialTermMatches(replacementTerms: Seq[ReplaceTerm])
: PartialFunction[m.Tree, (m.Term, ReplaceTerm)] =
replacementTerms.map(partialTermMatch).reduce(_ orElse _)

def terms(ptm: PartialFunction[m.Tree, (m.Term, ReplaceTerm)]) =
ast.collect(ptm)

def termReplacements(trms: Seq[(m.Term, ReplaceTerm)]): Seq[Patch] =
trms.map {
case (t, ReplaceTerm(oldTerm, newTerm, imported)) =>
val replacement = newTerm.termNames
.map(_.syntax)
.diff(imported.termNames.map(_.syntax)) //Strip off imported portion
.filterNot(_ == "apply") //special handling for apply, suppress as it will be inferred
.mkString(".")
Patch.replace(t, replacement)
}

def gatherPatches(tr: Seq[ReplaceTerm]): Seq[Patch] =
termReplacements(terms(partialTermMatches(tr)))
}
39 changes: 39 additions & 0 deletions core/src/main/scala/scalafix/util/ChangeType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package scalafix.util

import scala.collection.immutable.Seq
import scala.{meta => m}
import scalafix.rewrite._

//Provide a little structure to the replacements we will be performing
case class ReplaceType(original: m.Type,
replacement: m.Type,
newString: String) {
def toPatch(t: m.Type): Patch = Patch.replace(t, newString)
}

class ChangeType(ast: m.Tree)(implicit sApi: SemanticApi) {
import sApi._

def partialTypeMatch(
rt: ReplaceType): PartialFunction[m.Tree, (m.Type, ReplaceType)] = {
case t @ DType(desugared)
if StructurallyEqual(desugared, rt.original).isRight =>
t -> rt
}

def partialTypeMatches(replacementTypes: Seq[ReplaceType])
: PartialFunction[m.Tree, (m.Type, ReplaceType)] =
replacementTypes.map(partialTypeMatch).reduce(_ orElse _)

def tpes(ptm: PartialFunction[m.Tree, (m.Type, ReplaceType)])
: Seq[(m.Type, ReplaceType)] = ast.collect { ptm }

//This is unsafe, come up with something better
def typeReplacements(tpes: Seq[(m.Type, ReplaceType)]): Seq[Patch] =
tpes.map {
case (tree, rt) => rt.toPatch(tree)
}

def gatherPatches(tr: Seq[ReplaceType]): Seq[Patch] =
typeReplacements(tpes(partialTypeMatches(tr)))
}
20 changes: 20 additions & 0 deletions core/src/main/scala/scalafix/util/Patch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,24 @@ object Patch {
.map(_.syntax)
.mkString("")
}

def replace(token: Token, replacement: String): Patch =
Patch(token, token, replacement)

def replace(tree: Tree, replacement: String): Patch =
Patch(tree.tokens.head, tree.tokens.last, replacement)

def insertBefore(token: Token, toPrepend: String) =
replace(token, s"$toPrepend${token.syntax}")

def insertBefore(tree: Tree, toPrepend: String): Patch =
replace(tree, s"$toPrepend${tree.syntax}")

def insertAfter(token: Token, toAppend: String) =
replace(token, s"$toAppend${token.syntax}")

def insertAfter(tree: Tree, toAppend: String): Patch =
replace(tree, s"${tree.syntax}$toAppend")

def delete(tree: Tree): Patch = replace(tree, "")
}
42 changes: 42 additions & 0 deletions core/src/main/scala/scalafix/util/StructurallyEqual.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package scalafix.util

import scala.collection.immutable.Seq

object StructurallyEqual {
import scala.meta.Tree

/** Test if two trees are structurally equal.
* @return Left(errorMessage with minimal diff) if trees are not structurally
* different, otherwise Right(Unit). To convert into exception with
* meaningful error message,
* val Right(_) = StructurallyEqual(a, b)
**/
def apply(a: Tree, b: Tree): Either[AnyDiff, Unit] = {
def loop(x: Any, y: Any): Boolean = {
val ok: Boolean = (x, y) match {
case (x, y) if x == null || y == null => x == null && y == null
case (x: Some[_], y: Some[_]) => loop(x.get, y.get)
case (x: None.type, y: None.type) => true
case (xs: Seq[_], ys: Seq[_]) =>
xs.length == ys.length &&
xs.zip(ys).forall {
case (x, y) => loop(x, y)
}
case (x: Tree, y: Tree) =>
def sameStructure =
x.productPrefix == y.productPrefix &&
loop(x.productIterator.toList, y.productIterator.toList)
sameStructure
case _ => x == y
}
if (!ok) throw AnyDiff(x, y)
else true
}
try {
loop(a, b)
Right(Unit)
} catch {
case t: AnyDiff => Left(t)
}
}
}
25 changes: 25 additions & 0 deletions core/src/main/scala/scalafix/util/syntax/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package scalafix.util

import scala.{meta => m}
import scala.meta._

package object syntax {

//Allow two patterns to be combined. Contributed by @nafg
object & { def unapply[A](a: A) = Some((a, a)) }

implicit class MetaOps(from: m.Tree) {
def termNames: List[Term.Name] = {
from collect {
case t: Term.Name => t
}
}

def typeNames: List[Type.Name] = {
from collect {
case t: Type.Name => t
}
}
}

}
36 changes: 36 additions & 0 deletions core/src/test/resources/Xor/basic.source
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
rewrites = [Xor2Either]
<<< xor 1
import scala.concurrent.Future
import cats.data.{ Xor, XorT }
trait A {
type MyDisjunction = Xor[Int, String]
val r: MyDisjunction = Xor.Right.apply("")
val s: Xor[Int, String] = cats.data.Xor.Left(1 /* comment */)
val t: Xor[Int, String] = r.map(_ + "!")
val nest: Seq[Xor[Int, cats.data.Xor[String, Int]]]
val u: XorT[Future, Int, String] = ???
}
>>>
import cats.data.EitherT
import cats.implicits._
import scala.util.Either
import scala.concurrent.Future
import cats.data.{ Xor, XorT }
trait A {
type MyDisjunction = Either[Int, String]
val r: MyDisjunction = Right("")
val s: Either[Int, String] = Left(1 /* comment */)
val t: Either[Int, String] = r.map(_ + "!")
val nest: Seq[Either[Int, Either[String, Int]]]
val u: EitherT[Future, Int, String] = ???
}
<<< xor do not modify when not present
import scala.concurrent.Future
trait A {
val num = 1
}
>>>
import scala.concurrent.Future
trait A {
val num = 1
}
Loading

0 comments on commit fca3e03

Please sign in to comment.