Skip to content

Commit

Permalink
Use scalatest Runner.discoveredSuites scalatest/scalatest#2319, remov…
Browse files Browse the repository at this point in the history
…e all custom discovery
  • Loading branch information
neko-kai committed May 24, 2024
1 parent 83d0d88 commit 0852051
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 116 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.scalatest.distage

import _root_.distage.TagK
import io.github.classgraph.{ClassGraph, ClassInfo}
import izumi.distage.modules.DefaultModule
import izumi.distage.testkit.DebugProperties
import izumi.distage.testkit.model.{DistageTest, SuiteId}
Expand All @@ -11,16 +10,11 @@ import izumi.distage.testkit.services.scalatest.dstest.DistageTestsRegistrySingl
import izumi.distage.testkit.services.scalatest.dstest.{DistageTestsRegistrySingleton, SafeTestReporter}
import izumi.distage.testkit.spec.AbstractDistageSpec
import izumi.fundamentals.platform.console.TrivialLogger
import izumi.fundamentals.platform.functional.Identity
import izumi.fundamentals.platform.jvm.IzClasspath
import org.scalatest.*
import org.scalatest.exceptions.{DuplicateTestNameException, TestCanceledException}
import org.scalatest.tools.Runner

import java.nio.file.Paths
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Try
import scala.util.chaining.scalaUtilChainingOps
import scala.util.control.NonFatal

trait ScalatestInitWorkaround {
def awaitTestsLoaded(): Unit
Expand All @@ -39,40 +33,22 @@ object ScalatestInitWorkaround {
private val classpathScanned = new AtomicBoolean(false)
private val latch = new java.util.concurrent.CountDownLatch(1)

import scala.jdk.CollectionConverters.*

def awaitTestsLoaded(): Unit = {
latch.await()
}

def doScan[F[_]](instance: DistageScalatestTestSuiteRunner[F]): Unit = {
if (classpathScanned.compareAndSet(false, true)) {
val classLoader = instance.getClass.getClassLoader
val scan = new ClassGraph()
.enableClassInfo()
.addClassLoader(classLoader)
.pipe(instance.modifyClasspathScan)
.scan()
try {
val suiteClassName = classOf[DistageScalatestTestSuiteRunner[Identity]].getName

val allTestClasses = scan.getSubclasses(suiteClassName).asScala.filterNot(_.isAbstract)
val onlyTestClassesInCurrentModule = allTestClasses.filter(instance._sbtIsClassDefinedInCurrentTestModule(classLoader))

lazy val debugLogger = TrivialLogger.make[ScalatestInitWorkaroundImpl.type](DebugProperties.`izumi.distage.testkit.debug`.name)
onlyTestClassesInCurrentModule.foreach(
classInfo =>
Try {
debugLogger.log(s"Added scanned class `${classInfo.getName}` to current test run")
classInfo.loadClass().getDeclaredConstructor().newInstance()
}
)

DistageTestsRegistrySingleton.disableRegistration()
latch.countDown()
} finally {
scan.close()
val classNames = Runner.discoveredSuites.getOrElse(Set.empty)
val curName = instance.getClass.getName
if (classNames.nonEmpty && classNames != Set(curName)) {
classNames.foreach {
Class.forName(_).getDeclaredConstructor().newInstance()
}
}

DistageTestsRegistrySingleton.disableRegistration()
latch.countDown()
}
}
}
Expand All @@ -85,72 +61,6 @@ abstract class DistageScalatestTestSuiteRunner[F[_]](
) extends TestSuite
with AbstractDistageSpec[F] {

/**
* Modify test discovery options for SBT test runner only.
* Overriding this with [[withWhitelistJarsOnly]] will slightly boost test start-up speed,
* but will disable the ability to discover tests that inherit [[izumi.distage.testkit.services.scalatest.dstest.DistageAbstractScalatestSpec]]
* indirectly through a different library JAR. (this does not affect local sbt modules)
*/
def modifyClasspathScan: ClassGraph => ClassGraph = identity
protected final def withWhitelistJarsOnly: ClassGraph => ClassGraph = _.acceptJars("distage-testkit-scalatest*")

/**
* Override this to change the heuristic by which testkit determines that a test class is defined in the current SBT module.
*
* Affects SBT test runner only.
*
* By default we assume that classes with classfiles located in the first directory on the classpath
* which contains `test-classes` in its pathname are the classes defined in the current SBT test module.
*
* @see [[_sbtFindCurrentTestModuleClasspathElement]] - override this to change just the method for finding the `test-classes` directory not all the logic
*/
def _sbtIsClassDefinedInCurrentTestModule(classLoader: ClassLoader): ClassInfo => Boolean = {
val classpathElems = IzClasspath.safeClasspathSeq(classLoader)
_sbtFindCurrentTestModuleClasspathElement(classpathElems) match {
case Some(firstTestClassesDir) =>
val firstTestClassesDirPath = Paths.get(firstTestClassesDir).toAbsolutePath.toRealPath().normalize()
(classInfo: ClassInfo) => {
try {
val filePath = classInfo.getClasspathElementFile.toPath.toAbsolutePath.toRealPath().normalize()
val isInThisModuleTestClassesDir = firstTestClassesDirPath == filePath || filePath.startsWith(firstTestClassesDirPath)
if (!isInThisModuleTestClassesDir) {
_sbtReportFilteredOutTest(classInfo, filePath.toString, firstTestClassesDirPath.toString)
}
isInThisModuleTestClassesDir
} catch {
case NonFatal(t) =>
import izumi.fundamentals.platform.exceptions.IzThrowable.*
System.err.println(
s"DISTAGE-TESTKIT CRITICAL: Couldn't determine if a test class className=`${classInfo.getName}` was defined in the current SBT module due to error=${t.stacktraceString}" +
" including it by default"
)
true
}
}
case None =>
import izumi.fundamentals.platform.strings.IzString.*
System.err.println(
s"""DISTAGE-TESTKIT CRITICAL: Couldn't find a `test-classes` directory on the classpath, disabling fix preventing launch of tests defined in other sbt modules.
|Classpath was = ${classpathElems.niceList()}""".stripMargin
)
_ => true
}
}

/** Override this to change the method for finding the `test-classes` directory for [[_sbtIsClassDefinedInCurrentTestModule]] */
protected def _sbtFindCurrentTestModuleClasspathElement(classpathElems: Seq[String]): Option[String] = {
val firstTestClassesDir = classpathElems.find(elt => elt.contains("test-classes") && Try(Paths.get(elt).toFile.isDirectory).getOrElse(false))
firstTestClassesDir
}

/** Override this to change or remove log message warning about a filtered out test class in [[_sbtIsClassDefinedInCurrentTestModule]] */
protected def _sbtReportFilteredOutTest(cls: ClassInfo, fileClassPathElem: String, firstTestClassesClassPathElem: String): Unit = {
System.out.println(
s"DISTAGE-TESTKIT: Filtered out test class className=`${cls.getName}` because it was not defined in current SBT module."
+ s" Expected classpath element: expected=`$firstTestClassesClassPathElem` but got actual=`$fileClassPathElem`"
)
}

// initialize status early, so that runner can set it to `true` even before this test is discovered
// by scalatest, if it was already executed by that time
private[this] val status: StatefulStatus = DistageTestsRegistrySingleton.registerStatus[F](suiteId)
Expand All @@ -162,18 +72,8 @@ abstract class DistageScalatestTestSuiteRunner[F[_]](

override def run(testName: Option[String], args: Args): Status = {
DistageTestsRegistrySingleton.registerSuiteReporter(suiteId)(SuiteReporter(args.tracker, args.reporter))
// If, we're running under sbt, scan the classpath manually to add all tests
// in the classloader before starting anything, because sbt runner
// instantiates & runs tests at the same time, so when `run` is called
// NOT all tests have been registered, so we must force all tests, otherwise
// we can't be sure.
//
// NON-sbt ScalatestRunner first instantiates ALL tests, THEN calls `.run` method,
// so for non-sbt runs we KNOW that all tests have already been registered, so we
// don't have to scan the classpath ourselves.
if (args.reporter.getClass.getName.contains("org.scalatest.tools.Framework")) {
ScalatestInitWorkaround.scan(this).awaitTestsLoaded()
}

ScalatestInitWorkaround.scan(this).awaitTestsLoaded()

try {
DistageTestsRegistrySingleton.proceedWithTests[F]() match {
Expand All @@ -198,9 +98,10 @@ abstract class DistageScalatestTestSuiteRunner[F[_]](
val testsInThisTestClass = DistageTestsRegistrySingleton.registeredTests[F].filter(_.meta.test.id.suite == SuiteId(suiteId))
val testsByName = testsInThisTestClass.groupBy(_.meta.test.id.name)
testsByName.foreach {
case (testName, tests) => if (tests.size > 1) {
throw new DuplicateTestNameException(testName, 0)
}
case (testName, tests) =>
if (tests.size > 1) {
throw new DuplicateTestNameException(testName, 0)
}
}
testsByName.keys.toSet
}
Expand Down
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ object V {

val kind_projector = "0.13.3"

val scalatest = "3.2.18"
val scalatest = "3.3.0-x-SNAPSHOT"

val cats = "2.10.0"
val cats_effect = "3.5.4"
Expand Down

0 comments on commit 0852051

Please sign in to comment.