Skip to content

Commit

Permalink
MID-4475 some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
katkav committed Mar 26, 2018
1 parent 0c0b849 commit c630bd3
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 37 deletions.
Expand Up @@ -22,10 +22,13 @@
import org.apache.cxf.common.util.Base64Utility;
import org.apache.cxf.jaxrs.client.WebClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

import com.evolveum.midpoint.model.api.ModelInteractionService;
import com.evolveum.midpoint.model.api.ModelService;
import com.evolveum.midpoint.model.impl.security.NodeAuthenticationToken;
import com.evolveum.midpoint.model.impl.security.RestAuthenticationMethod;
import com.evolveum.midpoint.prism.PrismContext;
import com.evolveum.midpoint.prism.PrismObject;
Expand Down Expand Up @@ -76,13 +79,26 @@ public <O extends ObjectType> void invalidateCache(Class<O> type, String oid) {
return;
}

String nodeId = taskManager.getNodeId();

Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication instanceof NodeAuthenticationToken) {
NodeAuthenticationToken nodeAuthenticationToken = (NodeAuthenticationToken) authentication;
PrismObject<NodeType> node = nodeAuthenticationToken.getPrincipal();
if (nodeId.equals(node.asObjectable().getNodeIdentifier())) {
LOGGER.trace("Skipping cluster-wide cache celaring. We are on the same node.");
return;
}
}


Task task = taskManager.createTaskInstance("invalidateCache");
OperationResult result = task.getResult();

SearchResultList<PrismObject<NodeType>> resultList;
try {
String nodeId = taskManager.getNodeId();
ObjectQuery query = QueryBuilder.queryFor(NodeType.class, prismContext).item(NodeType.F_NODE_IDENTIFIER).eq(nodeId).build();

ObjectQuery query = QueryBuilder.queryFor(NodeType.class, prismContext).not().item(NodeType.F_NODE_IDENTIFIER).eq(nodeId).build();
resultList = modelService.searchObjects(NodeType.class, query, null, task, result);
} catch (SchemaException | ObjectNotFoundException | SecurityViolationException | CommunicationException
| ConfigurationException | ExpressionEvaluationException e) {
Expand Down Expand Up @@ -112,12 +128,10 @@ public <O extends ObjectType> void invalidateCache(Class<O> type, String oid) {
for (PrismObject<NodeType> node : resultList.getList()) {
NodeType nodeType = node.asObjectable();

String ipAddress = nodeType.getNodeIdentifier();

String httpPattern = clusterHttpPattern.replace("$host", nodeType.getHostname() + ":8080");
String httpPattern = clusterHttpPattern.replace("$host", nodeType.getHostname());

WebClient client = WebClient.create(httpPattern + "/ws/rest");
client.header("Authorization", RestAuthenticationMethod.CLUSTER + " " + Base64Utility.encode((ipAddress).getBytes()));
client.header("Authorization", RestAuthenticationMethod.CLUSTER.getMethod());// + " " + Base64Utility.encode((nodeIdentifier).getBytes()));

client.path("/event/" + ObjectTypes.getRestTypeFromClass(type));
Response response = client.post(null);
Expand Down
Expand Up @@ -34,6 +34,7 @@
import com.evolveum.midpoint.prism.query.ObjectQuery;
import com.evolveum.midpoint.prism.query.QueryJaxbConvertor;
import com.evolveum.midpoint.prism.query.builder.QueryBuilder;
import com.evolveum.midpoint.repo.api.CacheDispatcher;
import com.evolveum.midpoint.repo.common.CacheRegistry;
import com.evolveum.midpoint.schema.DefinitionProcessingOption;
import com.evolveum.midpoint.schema.DeltaConvertor;
Expand Down Expand Up @@ -126,7 +127,8 @@ public class ModelRestService {
@Autowired private TaskManager taskManager;
@Autowired private Protector protector;
@Autowired private ResourceValidator resourceValidator;
@Autowired private CacheRegistry cacheRegistry;

@Autowired private CacheDispatcher cacheDispatcher;

private static final Trace LOGGER = TraceManager.getTrace(ModelRestService.class);

Expand Down Expand Up @@ -1068,15 +1070,16 @@ public Response executeCredentialReset(@PathParam("oid") String oid, ExecuteCred
@Produces({MediaType.APPLICATION_XML, MediaType.APPLICATION_JSON, "application/yaml"})
public Response executeClusterEvent(@PathParam("type") String type, @Context MessageContext mc) {
//TODO: task??
// Task task = RestServiceUtil.initRequest(mc);
// OperationResult result = task.getResult().createSubresult(OPERATION_EXECUTE_CLUSTER_EVENT);
Task task = RestServiceUtil.initRequest(mc);
OperationResult result = new OperationResult(OPERATION_EXECUTE_CLUSTER_EVENT);
cacheRegistry.clearAllCaches();
String oid = "";
Class clazz = ObjectTypes.getClassFromRestType(type);
cacheDispatcher.dispatch(clazz, oid);

// result.computeStatus();
// finishRequest(task);
result.recordSuccess();
return RestServiceUtil.createResponse(Response.Status.OK, result);
Response response = RestServiceUtil.createResponse(Response.Status.OK, result);
finishRequest(task);
return response;

}

Expand Down
Expand Up @@ -43,7 +43,11 @@
import com.evolveum.midpoint.repo.api.RepositoryService;
import com.evolveum.midpoint.schema.SearchResultList;
import com.evolveum.midpoint.schema.result.OperationResult;
import com.evolveum.midpoint.security.api.HttpConnectionInformation;
import com.evolveum.midpoint.security.api.SecurityContextManager;
import com.evolveum.midpoint.security.api.SecurityUtil;
import com.evolveum.midpoint.task.api.Task;
import com.evolveum.midpoint.task.api.TaskManager;
import com.evolveum.midpoint.util.exception.SchemaException;
import com.evolveum.midpoint.xml.ns._public.common.common_3.NodeType;

Expand All @@ -62,12 +66,14 @@ public class MidpointRestAuthenticationHandler implements ContainerRequestFilter
@Autowired
@Qualifier("cacheRepositoryService")
private RepositoryService repository;
@Autowired private PrismContext prismContext;
@Autowired private SecurityContextManager securityContextManager;

@Autowired private NodeAuthenticator nodeAuthenticator;
@Autowired private TaskManager taskManager;

@Override
public void filter(ContainerRequestContext request, ContainerResponseContext response) throws IOException {
// nothing to do

}

@Override
Expand Down Expand Up @@ -95,6 +101,40 @@ public void filter(ContainerRequestContext requestCtx) throws IOException {
RestServiceUtil.createSecurityQuestionAbortMessage(requestCtx, "{\"user\" : \"username\"}");
return;
}

//TODO: audit login/logout?

if (RestAuthenticationMethod.CLUSTER.equals(authenticationType)) {
HttpConnectionInformation connectionInfo = SecurityUtil.getCurrentConnectionInformation();
String remoteAddress = connectionInfo.getRemoteHostAddress();


if (!nodeAuthenticator.authenticate(null, remoteAddress, "invalidateCache")) {
RestServiceUtil.createAbortMessage(requestCtx);
return;
}
Task task = taskManager.createTaskInstance();
m.put(RestServiceUtil.MESSAGE_PROPERTY_TASK_NAME, task);
// try {
// decodedCredentials = new String(Base64Utility.decode(base64Credentials));
// ObjectQuery query = QueryBuilder.queryFor(NodeType.class, prismContext).item(NodeType.F_NODE_IDENTIFIER).contains(decodedCredentials).build();
// OperationResult result = new OperationResult("authenticate node");
// SearchResultList<PrismObject<NodeType>> nodes = repository.searchObjects(NodeType.class, query, null, result);
// if (nodes.size() != 1) {
// RestServiceUtil.createAbortMessage(requestCtx);
// return;
// }
// //TODO: http header
//
// PreAuthenticatedAuthenticationToken authentication = new PreAuthenticatedAuthenticationToken(nodes.iterator().next(), null);
// SecurityContext securityContext = SecurityContextHolder.getContext();
// securityContext.setAuthentication(authentication);
// } catch (Base64Exception | SchemaException e) {
// RestServiceUtil.createAbortMessage(requestCtx);
// return;
// }
}
return;
}

if (parts.length != 2) {
Expand All @@ -110,34 +150,14 @@ public void filter(ContainerRequestContext requestCtx) throws IOException {
policy.setAuthorizationType(RestAuthenticationMethod.SECURITY_QUESTIONS.getMethod());
policy.setAuthorization(decodedCredentials);
securityQuestionAuthenticator.handleRequest(policy, m, requestCtx);

} catch (Base64Exception e) {
RestServiceUtil.createSecurityQuestionAbortMessage(requestCtx, "{\"user\" : \"username\"}");
return;
}
}

//TODO: audit login/logout?

if (RestAuthenticationMethod.CLUSTER.equals(authenticationType)) {
String decodedCredentials;
try {
decodedCredentials = new String(Base64Utility.decode(base64Credentials));
ObjectQuery query = QueryBuilder.queryFor(NodeType.class, prismContext).item(NodeType.F_NODE_IDENTIFIER).contains(decodedCredentials).build();
OperationResult result = new OperationResult("authenticate node");
SearchResultList<PrismObject<NodeType>> nodes = repository.searchObjects(NodeType.class, query, null, result);
if (nodes.size() != 1) {
RestServiceUtil.createAbortMessage(requestCtx);
return;
}

PreAuthenticatedAuthenticationToken authentication = new PreAuthenticatedAuthenticationToken(nodes.iterator().next(), null);
SecurityContext securityContext = SecurityContextHolder.getContext();
securityContext.setAuthentication(authentication);
} catch (Base64Exception | SchemaException e) {
RestServiceUtil.createAbortMessage(requestCtx);
return;
}
}

}

Expand Down
@@ -0,0 +1,34 @@
package com.evolveum.midpoint.model.impl.security;

import java.util.Collection;

import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.GrantedAuthority;

import com.evolveum.midpoint.prism.PrismObject;
import com.evolveum.midpoint.xml.ns._public.common.common_3.NodeType;

public class NodeAuthenticationToken extends AbstractAuthenticationToken {

private static final long serialVersionUID = 1L;

private PrismObject<NodeType> node;
private String remoteAddress;

public NodeAuthenticationToken(PrismObject<NodeType> node, String remoteAddress, Collection<? extends GrantedAuthority> authorities) {
super(authorities);
this.node = node;
this.remoteAddress = remoteAddress;
}

@Override
public Object getCredentials() {
return remoteAddress;
}

@Override
public PrismObject<NodeType> getPrincipal() {
return node;
}

}
@@ -0,0 +1,112 @@
package com.evolveum.midpoint.model.impl.security;

import java.util.ArrayList;
import java.util.List;

import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;

import com.evolveum.midpoint.prism.PrismContext;
import com.evolveum.midpoint.prism.PrismObject;
import com.evolveum.midpoint.prism.query.ObjectQuery;
import com.evolveum.midpoint.prism.query.builder.QueryBuilder;
import com.evolveum.midpoint.prism.query.builder.S_FilterEntryOrEmpty;
import com.evolveum.midpoint.repo.api.RepositoryService;
import com.evolveum.midpoint.schema.constants.SchemaConstants;
import com.evolveum.midpoint.schema.result.OperationResult;
import com.evolveum.midpoint.security.api.ConnectionEnvironment;
import com.evolveum.midpoint.util.exception.SchemaException;
import com.evolveum.midpoint.util.logging.LoggingUtils;
import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;
import com.evolveum.midpoint.xml.ns._public.common.common_3.NodeType;

@Component
public class NodeAuthenticator {

@Autowired
@Qualifier("cacheRepositoryService")
private RepositoryService repositoryService;
@Autowired private PrismContext prismContext;


@Autowired SecurityHelper securityHelper;

private static final Trace LOGGER = TraceManager.getTrace(NodeAuthenticator.class);

private static final String OPERATION_SEARCH_NODE = NodeAuthenticator.class.getName() + ".searchNode";

public boolean authenticate(String remoteName, String remoteAddress, String operation) {
LOGGER.debug("Checking if {} is a known node", remoteName);
OperationResult result = new OperationResult(OPERATION_SEARCH_NODE);

ConnectionEnvironment connEnv = ConnectionEnvironment.create(SchemaConstants.CHANNEL_REST_URI);

try {

List<PrismObject<NodeType>> knownNodes = repositoryService.searchObjects(NodeType.class,
null, null, result);

List<PrismObject<NodeType>> matchingNodes = getMatchingNodes(knownNodes, remoteName, remoteAddress, operation);

if (matchingNodes.size() == 1) {
PrismObject<NodeType> actualNode = knownNodes.iterator().next();
LOGGER.trace(
"The node {} was recognized as a known node (remote host name {} matched). Attempting to execute the requested operation: {} ",
actualNode.asObjectable().getName(), actualNode.asObjectable().getHostname(), operation);
NodeAuthenticationToken authNtoken = new NodeAuthenticationToken(actualNode, remoteAddress,
CollectionUtils.EMPTY_COLLECTION);
SecurityContextHolder.getContext().setAuthentication(authNtoken);
securityHelper.auditLoginSuccess(actualNode.asObjectable(), connEnv);
return true;
}

} catch (RuntimeException | SchemaException e) {
LOGGER.error("Unhandled exception when listing nodes");
LoggingUtils.logUnexpectedException(LOGGER, "Unhandled exception when listing nodes", e);
}
securityHelper.auditLoginFailure(remoteName != null ? remoteName : remoteAddress, null, connEnv, "Failed to authneticate node.");
return false;
}

private List<PrismObject<NodeType>> getMatchingNodes(List<PrismObject<NodeType>> knownNodes, String remoteName, String remoteAddress, String operation) {
List<PrismObject<NodeType>> matchingNodes = new ArrayList<>();
for (PrismObject<NodeType> node : knownNodes) {
NodeType actualNode = node.asObjectable();
if (remoteName != null && remoteName.equalsIgnoreCase(actualNode.getHostname())) {
LOGGER.trace("The node {} was recognized as a known node (remote host name {} matched). Attempting to execute the requested operation: {} ",
actualNode.getName(), actualNode.getHostname(), operation);
matchingNodes.add(node);
continue;
}
if (actualNode.getIpAddress().contains(remoteAddress)) {
LOGGER.trace("The node {} was recognized as a known node (remote host address {} matched). Attempting to execute the requested operation: {} ",
actualNode.getName(), remoteAddress, operation);
matchingNodes.add(node);
continue;
}
}

return matchingNodes;
}

private ObjectQuery createQuery(String remoteName, String remoteAddress) {
S_FilterEntryOrEmpty filterBuilder = QueryBuilder.queryFor(NodeType.class, prismContext);

if (StringUtils.isNotBlank(remoteName)) {
return filterBuilder.item(NodeType.F_HOSTNAME).eq(remoteName).build();
}


if (StringUtils.isNotBlank(remoteAddress)) {
return filterBuilder.item(NodeType.F_IP_ADDRESS).contains(remoteAddress).build();
}

return null;
}

}
Expand Up @@ -53,8 +53,10 @@
import com.evolveum.midpoint.xml.ns._public.common.common_3.CredentialPolicyType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.CredentialsPolicyType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.FocusType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.NodeType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.NonceCredentialsPolicyType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.ObjectReferenceType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.ObjectType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.PasswordCredentialsPolicyType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.PasswordLifeTimeType;
import com.evolveum.midpoint.xml.ns._public.common.common_3.SecurityPolicyType;
Expand Down Expand Up @@ -83,6 +85,10 @@ public class SecurityHelper implements ModelAuditRecorder {
public void auditLoginSuccess(@NotNull UserType user, @NotNull ConnectionEnvironment connEnv) {
auditLogin(user.getName().getOrig(), user, connEnv, OperationResultStatus.SUCCESS, null);
}

public void auditLoginSuccess(@NotNull NodeType node, @NotNull ConnectionEnvironment connEnv) {
auditLogin(node.getName().getOrig(), null, connEnv, OperationResultStatus.SUCCESS, null);
}

@Override
public void auditLoginFailure(@Nullable String username, @Nullable UserType user, @NotNull ConnectionEnvironment connEnv, String message) {
Expand All @@ -100,7 +106,7 @@ private void auditLogin(@Nullable String username, @Nullable UserType user, @Not

AuditEventRecord record = new AuditEventRecord(AuditEventType.CREATE_SESSION, AuditEventStage.REQUEST);
record.setParameter(username);
if (user != null) {
if (user != null ) {
record.setInitiator(user.asPrismObject());
}
record.setTimestamp(System.currentTimeMillis());
Expand Down

0 comments on commit c630bd3

Please sign in to comment.