forked from scala-native/scala-native
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ScalaNativeJUnitPlugin.scala
308 lines (245 loc) · 11 KB
/
ScalaNativeJUnitPlugin.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
package scala.scalanative.junit.plugin
// Ported from Scala.js
import scala.annotation.tailrec
import scala.tools.nsc._
import scala.tools.nsc.plugins.{
Plugin => NscPlugin,
PluginComponent => NscPluginComponent
}
/** The Scala Native JUnit plugin replaces reflection-based test lookup with generated code.
*
* For each JUnit test `my.pkg.X`, it generates a bootstrapper module/object
* `my.pkg.X\$scalanative\$junit\$bootstrapper` implementing
* `scala.scalanative.junit.Bootstrapper`.
*
* The test runner uses these objects to obtain test metadata and dispatch to
* relevant methods.
*/
class ScalaNativeJUnitPlugin(val global: Global) extends NscPlugin {
  val name: String = "Scala Native JUnit plugin"

  val components: List[NscPluginComponent] =
    List(ScalaNativeJUnitPluginComponent)

  val description: String = "Makes JUnit test classes invokable in Scala Native"

  /** The `junit-inject` phase: for every concrete JUnit test class found
   *  directly inside a package, it appends a synthesized bootstrapper module
   *  to that package.
   *
   *  The phase is scheduled after `mixin` and before `nir` (see `runsAfter` /
   *  `runsBefore`), so the generated trees go through Scala Native code
   *  generation like user code.
   */
  object ScalaNativeJUnitPluginComponent
      extends plugins.PluginComponent
      with transform.Transform {
    val global: Global = ScalaNativeJUnitPlugin.this.global

    import global._
    import definitions._
    import rootMirror.getRequiredClass

    val phaseName: String = "junit-inject"
    val runsAfter: List[String] = List("mixin")
    override val runsBefore: List[String] = List("nir")

    protected def newTransformer(unit: CompilationUnit): Transformer =
      new ScalaNativeJUnitPluginTransformer

    /** Symbols of the JUnit annotations recognized by the plugin. */
    private object JUnitAnnots {
      val Test: ClassSymbol = getRequiredClass("org.junit.Test")
      val Before: ClassSymbol = getRequiredClass("org.junit.Before")
      val After: ClassSymbol = getRequiredClass("org.junit.After")
      val BeforeClass: ClassSymbol = getRequiredClass("org.junit.BeforeClass")
      val AfterClass: ClassSymbol = getRequiredClass("org.junit.AfterClass")
      val Ignore: ClassSymbol = getRequiredClass("org.junit.Ignore")
    }

    /** Names of the generated bootstrapper methods and their parameters. */
    private object Names {
      val beforeClass: TermName = newTermName("beforeClass")
      val afterClass: TermName = newTermName("afterClass")
      val before: TermName = newTermName("before")
      val after: TermName = newTermName("after")
      val testMetadata: TermName = newTermName("testMetadata")
      val tests: TermName = newTermName("tests")
      val invokeTest: TermName = newTermName("invokeTest")
      val newInstance: TermName = newTermName("newInstance")
      val instance: TermName = newTermName("instance")
      val name: TermName = newTermName("name")
    }

    private lazy val BootstrapperClass =
      getRequiredClass("scala.scalanative.junit.Bootstrapper")

    private lazy val TestClassMetadataClass =
      getRequiredClass("scala.scalanative.junit.TestClassMetadata")

    private lazy val TestMetadataClass =
      getRequiredClass("scala.scalanative.junit.TestMetadata")

    private lazy val FutureClass =
      getRequiredClass("scala.concurrent.Future")

    private lazy val FutureModule_successful =
      getMemberMethod(FutureClass.companionModule, newTermName("successful"))

    private lazy val SuccessModule_apply =
      getMemberMethod(getRequiredClass("scala.util.Success").companionModule,
                      nme.apply)

    /** Rewrites each `PackageDef` so that its statements are followed by one
     *  bootstrapper module per test class defined directly in that package.
     */
    class ScalaNativeJUnitPluginTransformer extends Transformer {
      override def transform(tree: Tree): Tree = tree match {
        case tree: PackageDef =>
          // True if `sym` (or, recursively, its superclass) has a method
          // member annotated with @Test.
          @tailrec
          def hasTests(sym: Symbol): Boolean = {
            sym.info.members.exists(m =>
              m.isMethod && m.hasAnnotation(JUnitAnnots.Test)) ||
            sym.superClass.exists && hasTests(sym.superClass)
          }

          // Only concrete, non-module, non-trait classes qualify: the
          // bootstrapper's `newInstance` instantiates the class with `New`.
          def isTest(sym: Symbol) = {
            sym.isClass &&
            !sym.isModuleClass &&
            !sym.isAbstract &&
            !sym.isTrait &&
            hasTests(sym)
          }

          // Only direct children of the package are inspected; classes
          // nested inside other classes/objects get no bootstrapper.
          val bootstrappers = tree.stats.collect {
            case clDef: ClassDef if isTest(clDef.symbol) =>
              genBootstrapper(clDef.symbol.asClass)
          }

          val newStats = tree.stats.map(transform) ++ bootstrappers
          treeCopy.PackageDef(tree, tree.pid, newStats)

        case tree =>
          super.transform(tree)
      }

      /** Creates the bootstrapper module for `testClass`, named
       *  `<testClass name>\$scalanative\$junit\$bootstrapper` and extending
       *  `Object` with `scala.scalanative.junit.Bootstrapper`.
       *
       *  The module and its module class are entered in `testClass.owner`'s
       *  scope, and the returned `ClassDef` contains every generated member.
       */
      def genBootstrapper(testClass: ClassSymbol): ClassDef = {
        // Create the module and its module class, and enter them in their
        // owner's scope.
        val (moduleSym, bootSym) = testClass.owner.newModuleAndClassSymbol(
          newTypeName(
            testClass.name.toString + "$scalanative$junit$bootstrapper"),
          testClass.pos,
          0L)

        val bootInfo =
          ClassInfoType(List(ObjectTpe, BootstrapperClass.toType),
                        newScope,
                        bootSym)
        bootSym.setInfo(bootInfo)
        moduleSym.setInfoAndEnter(bootSym.toTypeConstructor)
        bootSym.owner.info.decls.enter(bootSym)

        val testMethods = annotatedMethods(testClass, JUnitAnnots.Test)

        val defs = List(
          genConstructor(bootSym),
          genCallOnModule(bootSym,
                          Names.beforeClass,
                          testClass.companionModule,
                          JUnitAnnots.BeforeClass),
          genCallOnModule(bootSym,
                          Names.afterClass,
                          testClass.companionModule,
                          JUnitAnnots.AfterClass),
          genCallOnParam(bootSym, Names.before, testClass, JUnitAnnots.Before),
          genCallOnParam(bootSym, Names.after, testClass, JUnitAnnots.After),
          genTestMetadata(bootSym, testClass),
          genTests(bootSym, testMethods),
          genInvokeTest(bootSym, testClass, testMethods),
          genNewInstance(bootSym, testClass)
        )

        ClassDef(bootSym, defs)
      }

      /** Generates the no-argument constructor of `owner`, whose body only
       *  calls `Object`'s primary constructor.
       */
      private def genConstructor(owner: ClassSymbol): DefDef = {
        /* The constructor body must be a Block in order not to freak out the
         * JVM back-end.
         */
        val rhs = Block(
          gen.mkMethodCall(Super(owner, tpnme.EMPTY),
                           ObjectClass.primaryConstructor,
                           Nil,
                           Nil))

        val sym = owner.newClassConstructor(NoPosition)
        sym.setInfoAndEnter(MethodType(Nil, owner.tpe))
        typer.typedDefDef(newDefDef(sym, rhs)())
      }

      /** Generates `def <name>(): Unit` on `owner`. Its body calls, on
       *  `module`, every method of `module` annotated with `annot`.
       *
       *  Used with the test class's companion module for `@BeforeClass` and
       *  `@AfterClass`.
       */
      private def genCallOnModule(owner: ClassSymbol,
                                  name: TermName,
                                  module: Symbol,
                                  annot: Symbol): DefDef = {
        val sym = owner.newMethodSymbol(name)
        sym.setInfoAndEnter(MethodType(Nil, definitions.UnitTpe))

        // A fresh `Ident(module)` is built for every call so that no tree
        // node is shared between calls.
        val calls = annotatedMethods(module, annot)
          .map(gen.mkMethodCall(Ident(module), _, Nil, Nil))
          .toList

        typer.typedDefDef(newDefDef(sym, Block(calls: _*))())
      }

      /** Generates `def <name>(instance: Object): Unit` on `owner`. The body
       *  casts `instance` to `testClass` and calls every method of
       *  `testClass` annotated with `annot` on it.
       *
       *  Used for `@Before` and `@After`.
       */
      private def genCallOnParam(owner: ClassSymbol,
                                 name: TermName,
                                 testClass: Symbol,
                                 annot: Symbol): DefDef = {
        val sym = owner.newMethodSymbol(name)
        val instanceParam =
          sym.newValueParameter(Names.instance).setInfo(ObjectTpe)
        sym.setInfoAndEnter(
          MethodType(List(instanceParam), definitions.UnitTpe))

        // Build a fresh cast tree per call instead of sharing a single Tree
        // node as the qualifier of several method calls.
        val calls = annotatedMethods(testClass, annot)
          .map(m =>
            gen.mkMethodCall(castParam(instanceParam, testClass), m, Nil, Nil))
          .toList

        typer.typedDefDef(newDefDef(sym, Block(calls: _*))())
      }

      /** Generates `def testMetadata(): TestClassMetadata`. The body
       *  constructs a `TestClassMetadata` whose single argument records
       *  whether `testClass` carries `@Ignore`.
       */
      private def genTestMetadata(owner: ClassSymbol,
                                  testClass: ClassSymbol): DefDef = {
        val sym = owner.newMethodSymbol(Names.testMetadata)
        sym.setInfoAndEnter(
          MethodType(Nil, typeRef(NoType, TestClassMetadataClass, Nil))
        )

        val ignored = testClass.hasAnnotation(JUnitAnnots.Ignore)
        val isIgnored = Literal(Constant(ignored))
        val rhs = New(TestClassMetadataClass, isIgnored)

        typer.typedDefDef(newDefDef(sym, rhs)())
      }

      /** Generates `def tests(): Array[TestMetadata]`. The array holds one
       *  entry per method in `tests`, each built from the method name, an
       *  "is @Ignore'd" flag, and a re-instantiated `@Test` annotation
       *  carrying the original annotation arguments.
       */
      private def genTests(owner: ClassSymbol, tests: Scope): DefDef = {
        val sym = owner.newMethodSymbol(Names.tests)
        sym.setInfoAndEnter(
          MethodType(Nil,
                     typeRef(NoType, ArrayClass, List(TestMetadataClass.tpe))))

        val metadata = for (test <- tests) yield {
          // `tests` only contains methods annotated with @Test (see
          // `annotatedMethods`), so `.get` cannot fail here.
          val reifiedAnnot = New(
            JUnitAnnots.Test,
            test.getAnnotation(JUnitAnnots.Test).get.args: _*)

          val name = Literal(Constant(test.name.toString))

          val testIgnored = test.hasAnnotation(JUnitAnnots.Ignore)
          val isIgnored = Literal(Constant(testIgnored))

          New(TestMetadataClass, name, isIgnored, reifiedAnnot)
        }

        val rhs = ArrayValue(TypeTree(TestMetadataClass.tpe), metadata.toList)

        typer.typedDefDef(newDefDef(sym, rhs)())
      }

      /** Generates `def invokeTest(instance: Object, name: String): Future`.
       *  The body dispatches on `name` through an if/else chain over `tests`:
       *  the matching test method is invoked on `instance` (cast to
       *  `testClass`), and an unknown name throws `NoSuchMethodException`.
       */
      private def genInvokeTest(owner: ClassSymbol,
                                testClass: Symbol,
                                tests: Scope): DefDef = {
        val sym = owner.newMethodSymbol(Names.invokeTest)

        val instanceParam =
          sym.newValueParameter(Names.instance).setInfo(ObjectTpe)
        val nameParam = sym.newValueParameter(Names.name).setInfo(StringTpe)

        // Enter the method into the bootstrapper's decls, like every other
        // generated member (the original only called `setInfo` here).
        // NOTE(review): the raw `Future` type constructor as result type
        // presumably relies on this phase running after erasure.
        sym.setInfoAndEnter(
          MethodType(List(instanceParam, nameParam),
                     FutureClass.toTypeConstructor))

        val rhs = tests.foldRight[Tree] {
          Throw(New(typeOf[NoSuchMethodException], Ident(nameParam)))
        } { (test, next) =>
          val cond =
            gen.mkMethodCall(Ident(nameParam),
                             Object_equals,
                             Nil,
                             List(Literal(Constant(test.name.toString))))

          // Fresh cast tree per branch: avoids sharing one Tree node across
          // all `If` branches.
          val call =
            genTestInvocation(test, castParam(instanceParam, testClass))

          If(cond, call, next)
        }

        typer.typedDefDef(newDefDef(sym, rhs)())
      }

      /** Builds the tree that runs test method `sym` on `instance`.
       *
       *  A `Unit`-returning test is called, and then
       *  `Future.successful(Success(BoxedUnit.UNIT))` is returned. Any other
       *  result type reports a compile error and yields `EmptyTree`.
       */
      private def genTestInvocation(sym: Symbol, instance: Tree): Tree = {
        sym.tpe.resultType.typeSymbol match {
          case UnitClass =>
            val boxedUnit = gen.mkAttributedRef(definitions.BoxedUnit_UNIT)
            val newSuccess =
              gen.mkMethodCall(SuccessModule_apply, List(boxedUnit))
            Block(
              gen.mkMethodCall(instance, sym, Nil, Nil),
              gen.mkMethodCall(FutureModule_successful, List(newSuccess))
            )

          case _ =>
            reporter.error(sym.pos, "JUnit test must have Unit return type")
            EmptyTree
        }
      }

      /** Generates `def newInstance(): Object`, whose body is
       *  `new <testClass>()`.
       */
      private def genNewInstance(owner: ClassSymbol,
                                 testClass: ClassSymbol): DefDef = {
        val sym = owner.newMethodSymbol(Names.newInstance)
        sym.setInfoAndEnter(MethodType(Nil, ObjectTpe))
        typer.typedDefDef(newDefDef(sym, New(testClass))())
      }

      /** Returns a new `param.asInstanceOf[clazz]` tree. Every call builds a
       *  distinct tree.
       */
      private def castParam(param: Symbol, clazz: Symbol): Tree =
        gen.mkAsInstanceOf(Ident(param), clazz.tpe, any = false)

      /** Method members of `owner` (including inherited members, per
       *  `info.members`) that carry annotation `annot`.
       */
      private def annotatedMethods(owner: Symbol, annot: Symbol): Scope =
        owner.info.members.filter(m => m.isMethod && m.hasAnnotation(annot))
    }
  }
}