diff --git a/src/main/java/org/codehaus/groovy/control/customizers/SecureASTCustomizer.java b/src/main/java/org/codehaus/groovy/control/customizers/SecureASTCustomizer.java index 5a64348359a..6d7290d4426 100644 --- a/src/main/java/org/codehaus/groovy/control/customizers/SecureASTCustomizer.java +++ b/src/main/java/org/codehaus/groovy/control/customizers/SecureASTCustomizer.java @@ -611,7 +611,7 @@ public void call(final SourceUnit source, final GeneratorContext context, final } } - final SecuringCodeVisitor visitor = new SecuringCodeVisitor(); + final GroovyCodeVisitor visitor = createGroovyCodeVisitor(); ast.getStatementBlock().visit(visitor); for (ClassNode clNode : ast.getClasses()) { if (clNode!=classNode) { @@ -631,14 +631,18 @@ public void call(final SourceUnit source, final GeneratorContext context, final } } } - - private void checkMethodDefinitionAllowed(ClassNode owner) { + + protected GroovyCodeVisitor createGroovyCodeVisitor() { + return new SecuringCodeVisitor(); + } + + protected void checkMethodDefinitionAllowed(ClassNode owner) { if (isMethodDefinitionAllowed) return; List methods = filterMethods(owner); if (!methods.isEmpty()) throw new SecurityException("Method definitions are not allowed"); } - - private static List filterMethods(ClassNode owner) { + + protected static List filterMethods(ClassNode owner) { List result = new LinkedList(); List methods = owner.getMethods(); for (MethodNode method : methods) { @@ -650,7 +654,7 @@ private static List filterMethods(ClassNode owner) { return result; } - private void assertStarImportIsAllowed(final String packageName) { + protected void assertStarImportIsAllowed(final String packageName) { if (starImportsWhitelist != null && !starImportsWhitelist.contains(packageName)) { throw new SecurityException("Importing [" + packageName + "] is not allowed"); } @@ -659,7 +663,7 @@ private void assertStarImportIsAllowed(final String packageName) { } } - private void assertImportIsAllowed(final String className) { + protected void assertImportIsAllowed(final String className) { if (importsWhitelist != null && !importsWhitelist.contains(className)) { if (starImportsWhitelist != null) { // we should now check if the import is in the star imports @@ -685,7 +689,7 @@ private void assertImportIsAllowed(final String className) { } } - private void assertStaticImportIsAllowed(final String member, final String className) { + protected void assertStaticImportIsAllowed(final String member, final String className) { final String fqn = member.equals(className) ? member : className + "." + member; if (staticImportsWhitelist != null && !staticImportsWhitelist.contains(fqn)) { if (staticStarImportsWhitelist != null) { @@ -713,7 +717,7 @@ private void assertStaticImportIsAllowed(final String member, final String class * CodeVisitorSupport} class to make sure that future features of the language gets managed by this visitor. Thus, * adding a new feature would result in a compilation error if this visitor is not updated. */ - private class SecuringCodeVisitor implements GroovyCodeVisitor { + protected class SecuringCodeVisitor implements GroovyCodeVisitor { /** * Checks that a given statement is either in the whitelist or not in the blacklist. @@ -721,7 +725,7 @@ private class SecuringCodeVisitor implements GroovyCodeVisitor { * @param statement the statement to be checked * @throws SecurityException if usage of this statement class is forbidden */ - private void assertStatementAuthorized(final Statement statement) throws SecurityException { + protected void assertStatementAuthorized(final Statement statement) throws SecurityException { final Class clazz = statement.getClass(); if (statementsBlacklist != null && statementsBlacklist.contains(clazz)) { throw new SecurityException(clazz.getSimpleName() + "s are not allowed"); @@ -741,7 +745,7 @@ private void assertStatementAuthorized(final Statement statement) throws Securit * @param expression the expression to be checked * @throws SecurityException if usage of this expression class is forbidden */ - private void assertExpressionAuthorized(final Expression expression) throws SecurityException { + protected void assertExpressionAuthorized(final Expression expression) throws SecurityException { final Class clazz = expression.getClass(); if (expressionsBlacklist != null && expressionsBlacklist.contains(clazz)) { throw new SecurityException(clazz.getSimpleName() + "s are not allowed: " + expression.getText()); @@ -780,7 +784,7 @@ private void assertExpressionAuthorized(final Expression expression) throws Secu } } - private ClassNode getExpressionType(ClassNode objectExpressionType) { + protected ClassNode getExpressionType(ClassNode objectExpressionType) { return objectExpressionType.isArray() ? getExpressionType(objectExpressionType.getComponentType()) : objectExpressionType; } @@ -790,7 +794,7 @@ private ClassNode getExpressionType(ClassNode objectExpressionType) { * @param token the token to be checked * @throws SecurityException if usage of this token is forbidden */ - private void assertTokenAuthorized(final Token token) throws SecurityException { + protected void assertTokenAuthorized(final Token token) throws SecurityException { final int value = token.getType(); if (tokensBlacklist != null && tokensBlacklist.contains(value)) { throw new SecurityException("Token " + token + " is not allowed");