GEODE-8782: Add getPrincipal method to FunctionContext interface (#5840)
- Add the ability to retrieve the Principal from the FunctionContext
  when a SecurityManager is enabled.
jdeppe-pivotal committed Dec 16, 2020
1 parent 6c6b783 commit a42f89a
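
For illustration, a minimal sketch of what this change enables (the class name is hypothetical, not part of this commit): a function deployed to a secured cluster can now inspect the identity of its caller.

import org.apache.geode.cache.execute.Function;
import org.apache.geode.cache.execute.FunctionContext;

// Hypothetical sketch, not code from this commit.
public class WhoAmIFunction implements Function<Void> {
  @Override
  public void execute(FunctionContext<Void> context) {
    // Null when no SecurityManager is configured or the legacy
    // security service is in use.
    Object principal = context.getPrincipal();
    context.getResultSender()
        .lastResult(principal == null ? "unknown" : principal.toString());
  }
}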
Showing 15 changed files with 279 additions and 19 deletions.
2 changes: 2 additions & 0 deletions buildSrc/src/main/resources/japicmp_exceptions.json
@@ -1,4 +1,6 @@
{
"Class org.apache.geode.net.SSLParameterExtension": "Old implementation exposed an internal class",
"Method org.apache.geode.net.SSLParameterExtension.init(org.apache.geode.distributed.internal.DistributionConfig)": "Old implementation exposed an internal class",
"Class org.apache.geode.cache.execute.FunctionContext": "Interface not intended for client applications",
"Method org.apache.geode.cache.execute.FunctionContext.getPrincipal()": "Interface not intended for client applications"
}
145 changes: 145 additions & 0 deletions FunctionExecutionWithPrincipalDUnitTest.java (new file)
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional information regarding
* copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/

package org.apache.geode.internal.cache.execute;

import java.util.HashSet;
import java.util.Set;
import java.util.stream.Stream;

import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;

import org.apache.geode.cache.Region;
import org.apache.geode.cache.RegionShortcut;
import org.apache.geode.cache.client.ClientCache;
import org.apache.geode.cache.execute.Function;
import org.apache.geode.cache.execute.FunctionService;
import org.apache.geode.distributed.ConfigurationProperties;
import org.apache.geode.examples.SimpleSecurityManager;
import org.apache.geode.management.internal.security.TestFunctions.ReadFunction;
import org.apache.geode.test.dunit.rules.ClusterStartupRule;
import org.apache.geode.test.dunit.rules.MemberVM;
import org.apache.geode.test.dunit.rules.SerializableFunction;
import org.apache.geode.test.junit.rules.ClientCacheRule;
import org.apache.geode.test.junit.rules.ServerStarterRule;

public class FunctionExecutionWithPrincipalDUnitTest {

private static String PR_REGION_NAME = "partitioned-region";
private static String REGION_NAME = "replicated-region";
private static Region<String, String> replicateRegion;
private static Region<String, String> partitionedRegion;

private static Function<?> readFunction = new ReadFunction();

private static MemberVM locator;
private static MemberVM server1;
private static MemberVM server2;
private static ClientCache client;

@ClassRule
public static ClusterStartupRule cluster = new ClusterStartupRule();

@ClassRule
public static ClientCacheRule clientRule = new ClientCacheRule();

@BeforeClass
public static void beforeClass() throws Exception {
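// SimpleSecurityManager (from geode-examples) authenticates any user whose
// password equals the user name and authorizes permissions whose string starts
// with that name, so "cluster"/"cluster" can start members and "data"/"data"
// can read and write regions.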
locator = cluster.startLocatorVM(0, x -> x
.withSecurityManager(SimpleSecurityManager.class));
int locatorPort = locator.getPort();

SerializableFunction<ServerStarterRule> startupFunction = x -> x
.withConnectionToLocator(locatorPort)
.withCredential("cluster", "cluster")
.withProperty(ConfigurationProperties.SERIALIZABLE_OBJECT_FILTER,
"org.apache.geode.management.internal.security.TestFunctions*");

server1 = cluster.startServerVM(1, startupFunction);
server2 = cluster.startServerVM(2, startupFunction);

Stream.of(server1, server2).forEach(v -> v.invoke(() -> {
ClusterStartupRule.getCache().createRegionFactory(RegionShortcut.REPLICATE).create(
REGION_NAME);
ClusterStartupRule.getCache().createRegionFactory(RegionShortcut.PARTITION_REDUNDANT)
.create(PR_REGION_NAME);
}));

client = clientRule
.withLocatorConnection(locatorPort)
.withCredential("data", "data")
.createCache();

replicateRegion = clientRule.createProxyRegion(REGION_NAME);
partitionedRegion = clientRule.createProxyRegion(PR_REGION_NAME);

for (int i = 0; i < 10; i++) {
replicateRegion.put("key-" + i, "value-" + i);
partitionedRegion.put("key-" + i, "value-" + i);
}
}

@Test
public void verifyPrincipal_whenUsingReplicateRegion_andCallingOnRegion() {
FunctionService.onRegion(replicateRegion)
.execute(readFunction)
.getResult();
}

@Test
public void verifyPrincipal_whenUsingPartitionedRegion_andCallingOnRegion() {
FunctionService.onRegion(partitionedRegion)
.execute(readFunction)
.getResult();
}

@Test
public void verifyPrincipal_whenUsingPartitionedRegion_andCallingOnRegion_withFilter() {
Set<String> filter = new HashSet<>();
filter.add("key-1");
filter.add("key-2");
filter.add("key-4");
filter.add("key-7");

FunctionService.onRegion(partitionedRegion)
.withFilter(filter)
.execute(readFunction)
.getResult();
}

@Test
public void verifyPrincipal_whenUsingPartitionedRegion_andCallingOnServer() {
FunctionService.onServer(partitionedRegion.getRegionService())
.execute(readFunction)
.getResult();
}

@Test
public void verifyPrincipal_whenUsingPartitionedRegion_andCallingOnServers() {
FunctionService.onServers(partitionedRegion.getRegionService())
.execute(readFunction)
.getResult();
}

@Test
public void verifyPrincipal_whenUsingReplicateRegion_andCallingOnServers() {
FunctionService.onServers(replicateRegion.getRegionService())
.execute(readFunction)
.getResult();
}

}
@@ -1297,8 +1297,8 @@ fromData,22
toData,19

org/apache/geode/internal/cache/execute/FunctionRemoteContext,2
fromData,124
toData,102
fromData,145
toData,123

org/apache/geode/internal/cache/ha/HARegionQueue$DispatchedAndCurrentEvents,2
fromData,37
@@ -14,10 +14,12 @@
*/
package org.apache.geode.cache.execute;


import org.apache.logging.log4j.util.Strings;

import org.apache.geode.cache.Cache;
import org.apache.geode.distributed.DistributedMember;
import org.apache.geode.internal.security.LegacySecurityService;

/**
* Defines the execution context of a {@link Function}. It is required by the
@@ -97,4 +99,12 @@ default String getMemberName() {

return member.getId();
}

/**
* If available, returns the principal that has been authenticated to execute this function. This
* will always be null if the {@link LegacySecurityService} is in use.
*
* @return the principal that has been authenticated
*/
Object getPrincipal();
}
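
As a usage sketch (assuming a SecurityManager that returns String principals, as the SimpleSecurityManager in the new test does; the user name "admin" is hypothetical), a function can gate its work on the caller's identity:

import org.apache.geode.cache.execute.Function;
import org.apache.geode.cache.execute.FunctionContext;
import org.apache.geode.cache.execute.FunctionException;

public class AdminOnlyFunction implements Function<Void> {
  @Override
  public void execute(FunctionContext<Void> context) {
    Object principal = context.getPrincipal();
    // getPrincipal() returns null under the LegacySecurityService, so a
    // missing principal is treated as unauthorized rather than as a wildcard.
    if (!"admin".equals(principal)) {
      throw new FunctionException("rejected caller: " + principal);
    }
    context.getResultSender().lastResult("ok");
  }
}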
@@ -3740,7 +3740,7 @@ private ResultCollector executeOnMultipleNodes(final Function function,
FunctionRemoteContext context = new FunctionRemoteContext(function,
execution.getArgumentsForMember(recip.getId()), memKeys,
FunctionExecutionNodePruner.getBucketSet(this, memKeys, false, isBucketSetAsFilter),
execution.isReExecute(), execution.isFnSerializationReqd());
execution.isReExecute(), execution.isFnSerializationReqd(), getPrincipal());
recipMap.put(recip, context);
}
if (logger.isDebugEnabled()) {
@@ -3755,6 +3755,10 @@ private ResultCollector executeOnMultipleNodes(final Function function,
return localResultCollector;
}

private Object getPrincipal() {
return cache.getSecurityService().getPrincipal();
}

/**
* Single key execution on single node
*
@@ -3958,7 +3962,7 @@ public ResultCollector executeOnBucketSet(final Function function,
for (InternalDistributedMember recip : dest) {
FunctionRemoteContext context = new FunctionRemoteContext(function,
execution.getArgumentsForMember(recip.getId()), null, memberToBuckets.get(recip),
execution.isReExecute(), execution.isFnSerializationReqd());
execution.isReExecute(), execution.isFnSerializationReqd(), getPrincipal());
recipMap.put(recip, context);
}
final LocalResultCollector<?, ?> localRC = execution.getLocalResultCollector(function, rc);
@@ -4052,7 +4056,7 @@ private ResultCollector executeOnAllBuckets(final Function function,
for (InternalDistributedMember recip : memberToBuckets.keySet()) {
FunctionRemoteContext context = new FunctionRemoteContext(function,
execution.getArgumentsForMember(recip.getId()), null, memberToBuckets.get(recip),
execution.isReExecute(), execution.isFnSerializationReqd());
execution.isReExecute(), execution.isFnSerializationReqd(), getPrincipal());
recipMap.put(recip, context);
}
final LocalResultCollector<?, ?> localResultCollector =
@@ -4984,7 +4988,7 @@ private ResultCollector executeFunctionOnRemoteNode(InternalDistributedMember ta
resultSender);

FunctionRemoteContext context = new FunctionRemoteContext(function, object, routingKeys,
bucketArray, execution.isReExecute(), execution.isFnSerializationReqd());
bucketArray, execution.isReExecute(), execution.isFnSerializationReqd(), getPrincipal());

HashMap<InternalDistributedMember, FunctionRemoteContext> recipMap =
new HashMap<InternalDistributedMember, FunctionRemoteContext>();
@@ -2969,7 +2969,7 @@ public boolean verifyBucketBeforeGrabbing(final int buckId) {
public void executeOnDataStore(final Set localKeys, final Function function, final Object object,
final int prid, final int[] bucketArray, final boolean isReExecute,
final PartitionedRegionFunctionStreamingMessage msg, long time, ServerConnection servConn,
int transactionID) {
int transactionID, Object principal) {

if (!areAllBucketsHosted(bucketArray)) {
throw new BucketMovedException(
@@ -2984,7 +2984,7 @@ public void executeOnDataStore(final Set localKeys, final Function function, fin
new RegionFunctionContextImpl(getPartitionedRegion().getCache(), function.getId(),
this.partitionedRegion, object, localKeys, ColocationHelper
.constructAndGetAllColocatedLocalDataSet(this.partitionedRegion, bucketArray),
bucketArray, resultSender, isReExecute);
bucketArray, resultSender, isReExecute, principal);

FunctionStats stats = FunctionStatsManager.getFunctionStats(function.getId(), dm.getSystem());
long start = stats.startFunctionExecution(function.hasResult());
@@ -21,6 +21,7 @@
import org.apache.geode.cache.execute.FunctionContext;
import org.apache.geode.cache.execute.RegionFunctionContext;
import org.apache.geode.cache.execute.ResultSender;
import org.apache.geode.internal.cache.InternalCache;

/**
* Context available to application functions which is passed from GemFire to {@link Function}. <br>
@@ -45,6 +46,8 @@ public class FunctionContextImpl implements FunctionContext {

private final boolean isPossDup;

private final Object principal;

public FunctionContextImpl(final Cache cache, final String functionId, final Object args,
ResultSender resultSender) {
this(cache, functionId, args, resultSender, false);
@@ -57,6 +60,14 @@ public FunctionContextImpl(final Cache cache, final String functionId, final Obj
this.args = args;
this.resultSender = resultSender;
this.isPossDup = isPossibleDuplicate;

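// Capture the caller's authenticated principal (if any) from the cache's
// SecurityService at construction time.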
Object tmpPrincipal = null;
if (cache != null) {
if (((InternalCache) cache).getSecurityService() != null) {
tmpPrincipal = ((InternalCache) cache).getSecurityService().getPrincipal();
}
}
this.principal = tmpPrincipal;
}

/**
@@ -89,6 +100,8 @@ public String toString() {
buf.append(this.functionId);
buf.append(";args=");
buf.append(this.args);
buf.append(";principal=");
buf.append(getPrincipal());
buf.append(']');
return buf.toString();
}
@@ -111,4 +124,8 @@ public Cache getCache() throws CacheClosedException {
return cache;
}

@Override
public Object getPrincipal() {
return principal;
}
}
@@ -31,8 +31,6 @@

/**
* FunctionContext for remote/target nodes
*
*
*/
public class FunctionRemoteContext implements DataSerializable {

@@ -50,16 +48,19 @@ public class FunctionRemoteContext implements DataSerializable {

private Function function;

private Object principal;

public FunctionRemoteContext() {}

public FunctionRemoteContext(final Function function, Object object, Set filter,
int[] bucketArray, boolean isReExecute, boolean isFnSerializationReqd) {
int[] bucketArray, boolean isReExecute, boolean isFnSerializationReqd, Object principal) {
this.function = function;
this.args = object;
this.filter = filter;
this.bucketArray = bucketArray;
this.isReExecute = isReExecute;
this.isFnSerializationReqd = isFnSerializationReqd;
this.principal = principal;
}

@Override
@@ -84,6 +85,10 @@ public void fromData(DataInput in) throws IOException, ClassNotFoundException {
this.bucketArray = BucketSetHelper.fromSet(bucketSet);
}
this.isReExecute = DataSerializer.readBoolean(in);

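// The principal field was added in Geode 1.14; older senders will not have written it.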
if (StaticSerialization.getVersionForDataStream(in).isNotOlderThan(KnownVersion.GEODE_1_14_0)) {
this.principal = DataSerializer.readObject(in);
}
}

@Override
Expand All @@ -103,6 +108,11 @@ public void toData(DataOutput out) throws IOException {
DataSerializer.writeHashSet((HashSet) bucketSet, out);
}
DataSerializer.writeBoolean(this.isReExecute, out);

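// Only write the principal to receivers new enough (1.14 or later) to read it.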
if (StaticSerialization.getVersionForDataStream(out)
.isNotOlderThan(KnownVersion.GEODE_1_14_0)) {
DataSerializer.writeObject(this.principal, out);
}
}

public Set getFilter() {
@@ -129,6 +139,10 @@ public String getFunctionId() {
return functionId;
}

public Object getPrincipal() {
return principal;
}

@Override
public String toString() {

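
Note on the fromData/toData changes above: FunctionRemoteContext travels between members during rolling upgrades, so the new field is version-gated and is only written and read when the peer on the other side of the stream is at Geode 1.14 or newer. A self-contained sketch of the pattern (class and field names are placeholders, not part of this commit):

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.geode.DataSerializable;
import org.apache.geode.DataSerializer;
import org.apache.geode.internal.serialization.KnownVersion;
import org.apache.geode.internal.serialization.StaticSerialization;

// Placeholder names throughout; this only illustrates the version-gating
// pattern used by FunctionRemoteContext.
public class VersionedPayload implements DataSerializable {
  private String legacyField; // present in all supported versions
  private Object newField; // added in 1.14

  @Override
  public void toData(DataOutput out) throws IOException {
    // Pre-existing fields keep their original order so older readers still work.
    DataSerializer.writeString(legacyField, out);
    if (StaticSerialization.getVersionForDataStream(out)
        .isNotOlderThan(KnownVersion.GEODE_1_14_0)) {
      DataSerializer.writeObject(newField, out);
    }
  }

  @Override
  public void fromData(DataInput in) throws IOException, ClassNotFoundException {
    legacyField = DataSerializer.readString(in);
    if (StaticSerialization.getVersionForDataStream(in)
        .isNotOlderThan(KnownVersion.GEODE_1_14_0)) {
      newField = DataSerializer.readObject(in);
    }
  }
}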
