Skip to content

Commit

Permalink
Fix race condition in custom library invalidation
Browse files Browse the repository at this point in the history
During custom function libraries cache invalidation (in
ScriptExpressionFactory) there was a short time window during which
an empty cache could be used by mistake. It was between checking if
the cache is initialized and actually using its content. It is now
fixed.

This should resolve MID-8137.
  • Loading branch information
mederly committed Oct 4, 2022
1 parent e80d0d9 commit c964eff
Show file tree
Hide file tree
Showing 12 changed files with 418 additions and 132 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
package com.evolveum.midpoint.model.common.expression.script;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
Expand Down Expand Up @@ -45,6 +43,9 @@
import com.evolveum.midpoint.xml.ns._public.common.common_3.ScriptExpressionEvaluatorType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.SingleCacheStateInformationType;

import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;

import static com.evolveum.midpoint.schema.GetOperationOptions.createReadOnlyCollection;

/**
Expand All @@ -57,80 +58,114 @@ public class ScriptExpressionFactory implements Cache {

private static final String DEFAULT_LANGUAGE = "http://midpoint.evolveum.com/xml/ns/public/expression/language#Groovy";

private final Map<String, ScriptEvaluator> evaluatorMap = new HashMap<>();
private ObjectResolver objectResolver;
private final PrismContext prismContext;
private Collection<FunctionLibrary> functions;
private final RepositoryService repositoryService; // might be null during low-level testing
@NotNull private final Map<String, ScriptEvaluator> evaluatorMap = new HashMap<>();
@NotNull private final ObjectResolver objectResolver;
@NotNull private final PrismContext prismContext;

/** Null only in low-level tests. */
@Nullable private final RepositoryService repositoryService;

@NotNull private final Map<String, FunctionLibrary> customFunctionLibraryCache = new ConcurrentHashMap<>();
private final AtomicBoolean initialized = new AtomicBoolean(false);
/** Null only in low-level tests. */
@Nullable private final CacheRegistry cacheRegistry;

private CacheRegistry cacheRegistry;
/** Initialized at startup. The collection is immutable. */
@NotNull private final Collection<FunctionLibrary> standardFunctionLibraries;

/** The collection is immutable. */
private volatile Collection<FunctionLibrary> cachedCustomFunctionLibraries;

@PostConstruct
public void register() {
cacheRegistry.registerCache(this);
if (cacheRegistry != null) {
cacheRegistry.registerCache(this);
}
}

@PreDestroy
public void unregister() {
cacheRegistry.unregisterCache(this);
if (cacheRegistry != null) {
cacheRegistry.unregisterCache(this);
}
}

public ScriptExpressionFactory(PrismContext prismContext, RepositoryService repositoryService) {
public ScriptExpressionFactory(
@NotNull PrismContext prismContext,
@NotNull RepositoryService repositoryService,
@NotNull Collection<FunctionLibrary> standardFunctionLibraries,
@NotNull Collection<ScriptEvaluator> evaluators,
@NotNull CacheRegistry cacheRegistry,
@NotNull ObjectResolver objectResolver) {
this.prismContext = prismContext;
this.repositoryService = repositoryService;
}

public ObjectResolver getObjectResolver() {
return objectResolver;
this.repositoryService = Objects.requireNonNull(repositoryService);
this.standardFunctionLibraries = Collections.unmodifiableCollection(standardFunctionLibraries);
registerEvaluators(evaluators);
this.cacheRegistry = Objects.requireNonNull(cacheRegistry); // Important to be non-null to ensure consistency
this.objectResolver = objectResolver;
}

public void setObjectResolver(ObjectResolver objectResolver) {
@VisibleForTesting
public ScriptExpressionFactory(
@NotNull Collection<FunctionLibrary> standardFunctionLibraries,
@NotNull ObjectResolver objectResolver) {
this.prismContext = PrismContext.get();
this.repositoryService = null;
this.standardFunctionLibraries = Collections.unmodifiableCollection(standardFunctionLibraries);
this.cacheRegistry = null;
this.objectResolver = objectResolver;
}

public void setEvaluators(Collection<ScriptEvaluator> evaluators) {
private void registerEvaluators(@NotNull Collection<ScriptEvaluator> evaluators) {
for (ScriptEvaluator evaluator : evaluators) {
registerEvaluator(evaluator.getLanguageUrl(), evaluator);
registerEvaluator(evaluator);
}
}

public Collection<FunctionLibrary> getFunctions() {
return Collections.unmodifiableCollection(functions); // MID-4396
@VisibleForTesting
public void registerEvaluator(ScriptEvaluator evaluator) {
registerEvaluator(evaluator.getLanguageUrl(), evaluator);
}

public void setFunctions(Collection<FunctionLibrary> functions) {
this.functions = functions;
private void registerEvaluator(String language, ScriptEvaluator evaluator) {
if (evaluatorMap.containsKey(language)) {
throw new IllegalArgumentException("Evaluator for language " + language + " already registered");
}
evaluatorMap.put(language, evaluator);
}

public Map<String, ScriptEvaluator> getEvaluators() {
return evaluatorMap;
@VisibleForTesting
public @NotNull ObjectResolver getObjectResolver() {
return objectResolver;
}

@VisibleForTesting
@NotNull Collection<FunctionLibrary> getStandardFunctionLibraries() {
return standardFunctionLibraries;
}

public void setCacheRegistry(CacheRegistry registry) {
this.cacheRegistry = registry;
@VisibleForTesting
public @NotNull Map<String, ScriptEvaluator> getEvaluators() {
return evaluatorMap;
}

public ScriptExpression createScriptExpression(
ScriptExpressionEvaluatorType expressionType, ItemDefinition<?> outputDefinition,
ExpressionProfile expressionProfile, ExpressionFactory expressionFactory,
String shortDesc, OperationResult result)
ScriptExpressionEvaluatorType expressionType,
ItemDefinition<?> outputDefinition,
ExpressionProfile expressionProfile,
ExpressionFactory expressionFactory,
String shortDesc,
OperationResult result)
throws ExpressionSyntaxException, SecurityViolationException {

initializeCustomFunctionsLibraryCacheIfNeeded(expressionFactory, result);
//cache cleanup method

String language = getLanguage(expressionType);
ScriptEvaluator evaluator = getEvaluator(language, shortDesc);
ScriptExpression expression = new ScriptExpression(evaluator, expressionType);
expression.setPrismContext(prismContext);
expression.setOutputDefinition(outputDefinition);
expression.setObjectResolver(objectResolver);
Collection<FunctionLibrary> functionsToUse = new ArrayList<>(functions);
functionsToUse.addAll(customFunctionLibraryCache.values());
expression.setFunctions(functionsToUse);
Collection<FunctionLibrary> allFunctionLibraries = new ArrayList<>(standardFunctionLibraries);
allFunctionLibraries.addAll(
getCustomFunctionLibraries(expressionFactory, result));
expression.setFunctions(allFunctionLibraries);

// It is not very elegant to process expression profile and script expression profile here.
// It is somehow redundant, as it was already pre-processed in the expression evaluator/factory
Expand All @@ -147,7 +182,8 @@ public ScriptExpression createScriptExpression(
return expression;
}

private ScriptExpressionProfile processScriptExpressionProfile(ExpressionProfile expressionProfile, String language, String shortDesc) throws SecurityViolationException {
private ScriptExpressionProfile processScriptExpressionProfile(
ExpressionProfile expressionProfile, String language, String shortDesc) throws SecurityViolationException {
if (expressionProfile == null) {
return null;
}
Expand All @@ -172,49 +208,64 @@ private ScriptExpressionProfile processScriptExpressionProfile(ExpressionProfile
return scriptProfile;
}

private void initializeCustomFunctionsLibraryCacheIfNeeded(ExpressionFactory expressionFactory, OperationResult result)
private @NotNull Collection<FunctionLibrary> getCustomFunctionLibraries(
ExpressionFactory expressionFactory, OperationResult result)
throws ExpressionSyntaxException {
if (initialized.compareAndSet(false, true)) {
initializeCustomFunctionsLibraryCache(expressionFactory, result);
Collection<FunctionLibrary> current = cachedCustomFunctionLibraries;
if (current != null) {
return current;
}
}

private void initializeCustomFunctionsLibraryCache(ExpressionFactory expressionFactory, OperationResult result)
throws ExpressionSyntaxException {
if (repositoryService != null) {
OperationResult subResult = result
.createMinorSubresult(ScriptExpressionFactory.class.getName() + ".searchCustomFunctions");
ResultHandler<FunctionLibraryType> functionLibraryHandler = (object, parentResult) -> {
// TODO: determine profile from function library archetype
ExpressionProfile expressionProfile = MiscSchemaUtil.getExpressionProfile();
FunctionLibrary customLibrary = new FunctionLibrary();
customLibrary.setVariableName(object.getName().getOrig());
customLibrary.setGenericFunctions(
new CustomFunctions(object.asObjectable(), expressionFactory, expressionProfile));
customLibrary.setNamespace(MidPointConstants.NS_FUNC_CUSTOM);
customFunctionLibraryCache.put(object.getName().getOrig(), customLibrary);
return true;
};
try {
repositoryService.searchObjectsIterative(FunctionLibraryType.class, null, functionLibraryHandler,
createReadOnlyCollection(), true, subResult);
subResult.recordSuccessIfUnknown();
} catch (SchemaException | RuntimeException e) {
subResult.recordFatalError("Failed to initialize custom functions", e);
throw new ExpressionSyntaxException(
"An error occurred during custom libraries initialization. " + e.getMessage(), e);
}
} else {
if (repositoryService == null) {
LOGGER.warn("No repository service set for ScriptExpressionFactory; custom functions will not be loaded. This"
+ " can occur during low-level testing; never during standard system execution.");
return List.of(); // intentionally not caching this value
}

Collection<FunctionLibrary> fetched = fetchCustomFunctionLibraries(expressionFactory, result);
cachedCustomFunctionLibraries = fetched;
return fetched;
}

public void registerEvaluator(String language, ScriptEvaluator evaluator) {
if (evaluatorMap.containsKey(language)) {
throw new IllegalArgumentException("Evaluator for language " + language + " already registered");
private @NotNull Collection<FunctionLibrary> fetchCustomFunctionLibraries(
ExpressionFactory expressionFactory, OperationResult result)
throws ExpressionSyntaxException {
assert repositoryService != null;
Map<String, FunctionLibrary> customLibrariesMap = new HashMap<>();
ResultHandler<FunctionLibraryType> functionLibraryHandler = (object, parentResult) -> {
LOGGER.trace("Found {}", object);
// TODO: determine profile from function library archetype
ExpressionProfile expressionProfile = MiscSchemaUtil.getExpressionProfile();
FunctionLibrary customLibrary = new FunctionLibrary();
String libraryName = object.getName().getOrig();
customLibrary.setVariableName(libraryName);
customLibrary.setGenericFunctions(
new CustomFunctions(object.asObjectable(), expressionFactory, expressionProfile));
customLibrary.setNamespace(MidPointConstants.NS_FUNC_CUSTOM);
FunctionLibrary existing = customLibrariesMap.get(libraryName);
if (existing != null) {
LOGGER.warn("Multiple custom libraries with the name of '{}'? {} and {}", libraryName, existing, customLibrary);
}
customLibrariesMap.put(libraryName, customLibrary);
return true;
};
OperationResult subResult = result
.createMinorSubresult(ScriptExpressionFactory.class.getName() + ".searchCustomFunctions");
try {
LOGGER.trace("Searching for function libraries");
repositoryService.searchObjectsIterative(FunctionLibraryType.class, null, functionLibraryHandler,
createReadOnlyCollection(), true, subResult);
} catch (SchemaException | RuntimeException e) {
subResult.recordFatalError("Failed to initialize custom functions", e);
throw new ExpressionSyntaxException(
"An error occurred during custom libraries initialization. " + e.getMessage(), e);
} finally {
subResult.close();
}
evaluatorMap.put(language, evaluator);
LOGGER.debug("Function libraries found: {}", customLibrariesMap.size());
return Collections.unmodifiableCollection(
new ArrayList<>(
customLibrariesMap.values()));
}

private @NotNull ScriptEvaluator getEvaluator(String languageUri, String shortDesc) throws ExpressionSyntaxException {
Expand Down Expand Up @@ -247,25 +298,26 @@ private String getLanguage(ScriptExpressionEvaluatorType expressionType) {
@Override
public void invalidate(Class<?> type, String oid, CacheInvalidationContext context) {
if (type == null || type.isAssignableFrom(FunctionLibraryType.class)) {
LOGGER.trace("Invalidating custom functions library cache");
// Currently we don't try to select entries to be cleared based on OID
customFunctionLibraryCache.clear();
initialized.set(false);
cachedCustomFunctionLibraries = null;
}
}

@NotNull
@Override
public Collection<SingleCacheStateInformationType> getStateInformation() {
return Collections.singleton(new SingleCacheStateInformationType(prismContext)
return Collections.singleton(new SingleCacheStateInformationType()
.name(ScriptExpressionFactory.class.getName())
.size(customFunctionLibraryCache.size()));
.size(cachedCustomFunctionLibraries.size()));
}

@Override
public void dumpContent() {
if (LOGGER_CONTENT.isInfoEnabled()) {
if (initialized.get()) {
customFunctionLibraryCache.forEach((k, v) -> LOGGER_CONTENT.info("Cached function library: {}: {}", k, v));
Collection<FunctionLibrary> cached = cachedCustomFunctionLibraries;
if (cached != null) {
cached.forEach(v -> LOGGER_CONTENT.info("Cached function library: {}", v));
} else {
LOGGER_CONTENT.info("Custom function library cache is not yet initialized");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@
import com.evolveum.midpoint.prism.PrismContext;
import com.evolveum.midpoint.prism.crypto.KeyStoreBasedProtectorBuilder;
import com.evolveum.midpoint.prism.crypto.Protector;
import com.evolveum.midpoint.repo.api.RepositoryService;
import com.evolveum.midpoint.repo.common.ObjectResolver;
import com.evolveum.midpoint.repo.common.expression.ExpressionFactory;
import com.evolveum.midpoint.repo.common.expression.evaluator.AsIsExpressionEvaluatorFactory;
import com.evolveum.midpoint.repo.common.expression.evaluator.LiteralExpressionEvaluatorFactory;
import com.evolveum.midpoint.security.api.SecurityContextManager;
import com.evolveum.midpoint.test.util.MidPointTestConstants;

/**
* @author Radovan Semancik
*/
public class ExpressionTestUtil {

public static final String CONST_FOO_NAME = "foo";
private static final String CONST_FOO_NAME = "foo";
public static final String CONST_FOO_VALUE = "foobar";

public static Protector createInitializedProtector(PrismContext prismContext) {
Expand All @@ -52,10 +50,9 @@ public static Protector createInitializedProtector(PrismContext prismContext) {
}

public static ExpressionFactory createInitializedExpressionFactory(
ObjectResolver resolver, Protector protector, PrismContext prismContext, Clock clock,
SecurityContextManager securityContextManager, RepositoryService repositoryService) {
ObjectResolver resolver, Protector protector, PrismContext prismContext, Clock clock) {
ExpressionFactory expressionFactory = new ExpressionFactory(
securityContextManager, prismContext, LocalizationTestUtil.getLocalizationService());
null, prismContext, LocalizationTestUtil.getLocalizationService());
expressionFactory.setObjectResolver(resolver);

// NOTE: we need to register the evaluator factories to expressionFactory manually here
Expand Down Expand Up @@ -97,24 +94,21 @@ public static ExpressionFactory createInitializedExpressionFactory(
Collection<FunctionLibrary> functions = new ArrayList<>();
functions.add(FunctionLibraryUtil.createBasicFunctionLibrary(prismContext, protector, clock));
functions.add(FunctionLibraryUtil.createLogFunctionLibrary(prismContext));
ScriptExpressionFactory scriptExpressionFactory =
new ScriptExpressionFactory(prismContext, repositoryService);
scriptExpressionFactory.setObjectResolver(resolver);
scriptExpressionFactory.setFunctions(functions);
ScriptExpressionFactory scriptExpressionFactory = new ScriptExpressionFactory(functions, resolver);

GroovyScriptEvaluator groovyEvaluator = new GroovyScriptEvaluator(
prismContext, protector, LocalizationTestUtil.getLocalizationService());
scriptExpressionFactory.registerEvaluator(groovyEvaluator.getLanguageUrl(), groovyEvaluator);
scriptExpressionFactory.registerEvaluator(groovyEvaluator);

Jsr223ScriptEvaluator jsEvaluator = new Jsr223ScriptEvaluator(
"ECMAScript", prismContext, protector, LocalizationTestUtil.getLocalizationService());
if (jsEvaluator.isInitialized()) {
scriptExpressionFactory.registerEvaluator(jsEvaluator.getLanguageUrl(), jsEvaluator);
scriptExpressionFactory.registerEvaluator(jsEvaluator);
}

ScriptExpressionEvaluatorFactory scriptExpressionEvaluatorFactory =
new ScriptExpressionEvaluatorFactory(
scriptExpressionFactory, securityContextManager, prismContext);
scriptExpressionFactory, null, prismContext);
expressionFactory.registerEvaluatorFactory(scriptExpressionEvaluatorFactory);

return expressionFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ public void setup() throws SchemaException, SAXException, IOException {
Protector protector = ExpressionTestUtil.createInitializedProtector(prismContext);
Clock clock = new Clock();
constantsManager = new ConstantsManager();
expressionFactory = ExpressionTestUtil.createInitializedExpressionFactory(
resolver, protector, prismContext, clock, null, null);
expressionFactory = ExpressionTestUtil.createInitializedExpressionFactory(resolver, protector, prismContext, clock);

expressionProfile = compileExpressionProfile(getExpressionProfileName());
System.out.println("Using expression profile: " + expressionProfile);
Expand Down

0 comments on commit c964eff

Please sign in to comment.