Skip to content

Commit

Permalink
Fix ConcurrentModificationException on Redis / PT
Browse files Browse the repository at this point in the history
```
master: #6068
```
  • Loading branch information
leleuj committed Jun 25, 2024
1 parent ee7e65e commit 7a3fc4b
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
* Concrete implementation of a TicketGrantingTicket. A TicketGrantingTicket is
Expand Down Expand Up @@ -55,7 +56,7 @@ public class TicketGrantingTicketImpl extends AbstractTicket implements TicketGr
/**
* The services associated to this ticket.
*/
private Map<String, Service> services = new HashMap<>(0);
private Map<String, Service> services = new ConcurrentHashMap<>(0);

/**
* The {@link TicketGrantingTicket} this is associated with.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
package org.apereo.cas.ticket.registry;

import org.apereo.cas.CasProtocolConstants;
import org.apereo.cas.authentication.CoreAuthenticationTestUtils;
import org.apereo.cas.authentication.principal.Service;
import org.apereo.cas.authentication.principal.WebApplicationServiceFactory;
import org.apereo.cas.config.RedisCoreConfiguration;
import org.apereo.cas.config.RedisTicketRegistryConfiguration;
import org.apereo.cas.redis.core.CasRedisTemplate;
import org.apereo.cas.services.RegisteredServiceTestUtils;
import org.apereo.cas.ticket.ProxyGrantingTicketImpl;
import org.apereo.cas.ticket.ServiceTicket;
import org.apereo.cas.ticket.Ticket;
import org.apereo.cas.ticket.TicketGrantingTicket;
import org.apereo.cas.ticket.TicketGrantingTicketImpl;
import org.apereo.cas.ticket.expiration.HardTimeoutExpirationPolicy;
import org.apereo.cas.ticket.expiration.NeverExpiresExpirationPolicy;
import org.apereo.cas.ticket.proxy.ProxyGrantingTicket;
import org.apereo.cas.ticket.proxy.ProxyTicket;
import org.apereo.cas.ticket.tracking.TicketTrackingPolicy;
import org.apereo.cas.util.ProxyGrantingTicketIdGenerator;
import org.apereo.cas.util.ProxyTicketIdGenerator;
import org.apereo.cas.util.ServiceTicketIdGenerator;
import org.apereo.cas.util.TicketGrantingTicketIdGenerator;
import org.apereo.cas.util.junit.EnabledIfListeningOnPort;
Expand All @@ -29,6 +38,7 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.test.context.TestPropertySource;

import java.util.ArrayList;
Expand Down Expand Up @@ -345,7 +355,7 @@ private void addTicketAndWait(final String principalId) {
"cas.ticket.registry.redis.host=localhost",
"cas.ticket.registry.redis.port=6379"
})
class ConcurrentTests {
class ConcurrentAddTicketGrantingTicketTests {
@Autowired
@Qualifier(TicketRegistry.BEAN_NAME)
private TicketRegistry ticketRegistry;
Expand All @@ -356,7 +366,7 @@ void verifyConcurrentAddTicket() throws Throwable {
val testHasFailed = new AtomicBoolean();
val threads = new ArrayList<Thread>();
for (var i = 1; i <= 3; i++) {
val runnable = new RunnableAddTicket(ticketRegistry, principalId, 100);
val runnable = new RunnableAddTicketGrantingTicket(ticketRegistry, principalId, 100);
val thread = new Thread(runnable);
thread.setName("Thread-" + i);
thread.setUncaughtExceptionHandler((t, e) -> {
Expand All @@ -379,7 +389,7 @@ void verifyConcurrentAddTicket() throws Throwable {
}

@RequiredArgsConstructor
private static final class RunnableAddTicket implements Runnable {
private static final class RunnableAddTicketGrantingTicket implements Runnable {
private final TicketRegistry ticketRegistry;
private final String principalId;
private final int max;
Expand All @@ -397,4 +407,83 @@ public void run() {
}
}
}

@Nested
@SpringBootTest(
classes = {
RedisCoreConfiguration.class,
RedisTicketRegistryConfiguration.class,
BaseTicketRegistryTests.SharedTestConfiguration.class
}, properties = {
"cas.ticket.tgt.core.only-track-most-recent-session=false",
"cas.ticket.registry.redis.host=localhost",
"cas.ticket.registry.redis.port=6379"
})
class ConcurrentAddProxyTicketTests {
@Autowired
@Qualifier(TicketRegistry.BEAN_NAME)
private TicketRegistry ticketRegistry;

@Autowired
@Qualifier(org.apereo.cas.ticket.tracking.TicketTrackingPolicy.BEAN_NAME_SERVICE_TICKET_TRACKING)
private TicketTrackingPolicy serviceTicketSessionTrackingPolicy;

@Test
void verifyConcurrentAddTicket() throws Throwable {
val principalId = UUID.randomUUID().toString();
val authentication = CoreAuthenticationTestUtils.getAuthentication(principalId);
val tgtGenerator = new ProxyGrantingTicketIdGenerator(10, StringUtils.EMPTY);
val pgt = new ProxyGrantingTicketImpl(tgtGenerator.getNewTicketId(TicketGrantingTicket.PREFIX),
authentication, NeverExpiresExpirationPolicy.INSTANCE);
ticketRegistry.addTicket(pgt);

val request = new MockHttpServletRequest();
request.setParameter(CasProtocolConstants.PARAMETER_SERVICE, "http://foo.com");
val service = new WebApplicationServiceFactory().createService(request);

val testHasFailed = new AtomicBoolean();
val threads = new ArrayList<Thread>();
for (var i = 1; i <= 3; i++) {
val runnable = new RunnableAddProxyTicket(ticketRegistry, pgt, service, serviceTicketSessionTrackingPolicy, 100);
val thread = new Thread(runnable);
thread.setName("Thread-" + i);
thread.setUncaughtExceptionHandler((t, e) -> {
LOGGER.error(e.getMessage(), e);
testHasFailed.set(true);
});
threads.add(thread);
thread.start();
}
for (val thread : threads) {
try {
thread.join();
} catch (final Throwable e) {
fail(e);
}
}
if (testHasFailed.get()) {
fail("Test failed");
}
}

@RequiredArgsConstructor
private static final class RunnableAddProxyTicket implements Runnable {
private final TicketRegistry ticketRegistry;
private final ProxyGrantingTicket proxyGrantingTicket;
private final Service service;
private final TicketTrackingPolicy serviceTicketSessionTrackingPolicy;
private final int max;

@Override
@SneakyThrows
public void run() {
val ptGenerator = new ProxyTicketIdGenerator(10, StringUtils.EMPTY);
for (int i = 0; i < max; i++) {
val proxyTicket = proxyGrantingTicket.grantProxyTicket(ptGenerator.getNewTicketId(ProxyTicket.PREFIX),
service, new HardTimeoutExpirationPolicy(20), serviceTicketSessionTrackingPolicy);
ticketRegistry.addTicket(proxyTicket);
}
}
}
}
}

0 comments on commit 7a3fc4b

Please sign in to comment.