Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TransactionalTransform implements ASTTransformation{
private static final String METHOD_EXECUTE = "execute"
private static final Set<String> METHOD_NAME_EXCLUDES = new HashSet<String>(Arrays.asList("afterPropertiesSet", "destroy"));
private static final Set<String> ANNOTATION_NAME_EXCLUDES = new HashSet<String>(Arrays.asList(PostConstruct.class.getName(), PreDestroy.class.getName(), Transactional.class.getName(), Rollback.class.getName(), "grails.web.controllers.ControllerMethod", NotTransactional.class.getName()));
private static final Set<String> JUNIT_ANNOTATION_NAMES = new HashSet<String>(Arrays.asList("org.junit.Before", "org.junit.After"));
private static final String SPEC_CLASS = "spock.lang.Specification";
public static final String PROPERTY_DATA_SOURCE = "datasource"

Expand Down Expand Up @@ -116,7 +117,8 @@ class TransactionalTransform implements ASTTransformation{
for (MethodNode md in methods) {
String methodName = md.getName()
int modifiers = md.modifiers
if (!md.isSynthetic() && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers) && !Modifier.isStatic(modifiers)) {
if (!md.isSynthetic() && Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers) &&
!Modifier.isStatic(modifiers) && !hasJunitAnnotation(md)) {
if(hasExcludedAnnotation(md)) continue

def startsWithSpock = methodName.startsWith('$spock')
Expand All @@ -138,7 +140,8 @@ class TransactionalTransform implements ASTTransformation{
if(hasAnnotation(md, DelegatingMethod.class)) continue
weaveTransactionalMethod(source, classNode, annotationNode, md);
}
else if(("setup".equals(methodName) || "cleanup".equals(methodName)) && isSpockTest(classNode)) {
else if ((("setup".equals(methodName) || "cleanup".equals(methodName)) && isSpockTest(classNode)) ||
hasJunitAnnotation(md)) {
def requiresNewTransaction = new AnnotationNode(annotationNode.classNode)
requiresNewTransaction.addMember("propagation", new PropertyExpression(new ClassExpression(ClassHelper.make(Propagation.class)), "REQUIRES_NEW"))
weaveTransactionalMethod(source, classNode, requiresNewTransaction, md, "execute")
Expand All @@ -161,6 +164,17 @@ class TransactionalTransform implements ASTTransformation{
excludedAnnotation
}

private boolean hasJunitAnnotation(MethodNode md) {
boolean excludedAnnotation = false;
for (AnnotationNode annotation : md.getAnnotations()) {
if(JUNIT_ANNOTATION_NAMES.contains(annotation.getClassNode().getName())) {
excludedAnnotation = true;
break;
}
}
excludedAnnotation
}

ClassNode getAnnotationClassNode(String annotationName) {
try {
final classLoader = Thread.currentThread().contextClassLoader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,46 @@ class TransactionalTransformSpec extends Specification {

}

void "Test @Rollback when applied to JUnit specifications"() {
when:
Class mySpec = new GroovyShell().evaluate('''
import grails.transaction.*
import org.junit.Test
import org.junit.Before
import org.junit.After

@Rollback
class MyJunitTest {
@Before
def junitSetup() {

}

@After
def junitCleanup() {

}

@Test
void junitTest() {
expect:
1 == 1
}
}
MyJunitTest
''')

then: "It implements TransactionManagerAware"
TransactionManagerAware.isAssignableFrom(mySpec)
mySpec.getDeclaredMethod('junitSetup')
mySpec.getDeclaredMethod('$tt__junitSetup', TransactionStatus)
mySpec.getDeclaredMethod('junitCleanup')
mySpec.getDeclaredMethod('$tt__junitCleanup', TransactionStatus)

mySpec.getDeclaredMethod('junitTest')
mySpec.getDeclaredMethod('$tt__junitTest', TransactionStatus)
}

void "Test @Rollback when applied to Spock specifications"() {
when: "A new instance of a class with a @Transactional method is created that subclasses another transactional class"
Class mySpec = new GroovyShell().evaluate('''
Expand Down