From 7a3fc4b643ae03d31d188c976843c67fc65067b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?CAS=20in=20the=20cloud=20LELEU=20J=C3=A9r=C3=B4me?= Date: Tue, 25 Jun 2024 15:39:47 +0200 Subject: [PATCH] Fix ConcurrentModificationException on Redis / PT ``` master: https://github.com/apereo/cas/pull/6068 ``` --- .../cas/ticket/TicketGrantingTicketImpl.java | 3 +- .../RedisServerTicketRegistryTests.java | 95 ++++++++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/core/cas-server-core-tickets-api/src/main/java/org/apereo/cas/ticket/TicketGrantingTicketImpl.java b/core/cas-server-core-tickets-api/src/main/java/org/apereo/cas/ticket/TicketGrantingTicketImpl.java index cd5cb6088e46..67a3bc4608ad 100644 --- a/core/cas-server-core-tickets-api/src/main/java/org/apereo/cas/ticket/TicketGrantingTicketImpl.java +++ b/core/cas-server-core-tickets-api/src/main/java/org/apereo/cas/ticket/TicketGrantingTicketImpl.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; /** * Concrete implementation of a TicketGrantingTicket. A TicketGrantingTicket is @@ -55,7 +56,7 @@ public class TicketGrantingTicketImpl extends AbstractTicket implements TicketGr /** * The services associated to this ticket. */ - private Map services = new HashMap<>(0); + private Map services = new ConcurrentHashMap<>(0); /** * The {@link TicketGrantingTicket} this is associated with. diff --git a/support/cas-server-support-redis-ticket-registry/src/test/java/org/apereo/cas/ticket/registry/RedisServerTicketRegistryTests.java b/support/cas-server-support-redis-ticket-registry/src/test/java/org/apereo/cas/ticket/registry/RedisServerTicketRegistryTests.java index 3f82c6d24a44..b2bf05214d45 100644 --- a/support/cas-server-support-redis-ticket-registry/src/test/java/org/apereo/cas/ticket/registry/RedisServerTicketRegistryTests.java +++ b/support/cas-server-support-redis-ticket-registry/src/test/java/org/apereo/cas/ticket/registry/RedisServerTicketRegistryTests.java @@ -1,16 +1,25 @@ package org.apereo.cas.ticket.registry; +import org.apereo.cas.CasProtocolConstants; import org.apereo.cas.authentication.CoreAuthenticationTestUtils; +import org.apereo.cas.authentication.principal.Service; +import org.apereo.cas.authentication.principal.WebApplicationServiceFactory; import org.apereo.cas.config.RedisCoreConfiguration; import org.apereo.cas.config.RedisTicketRegistryConfiguration; import org.apereo.cas.redis.core.CasRedisTemplate; import org.apereo.cas.services.RegisteredServiceTestUtils; +import org.apereo.cas.ticket.ProxyGrantingTicketImpl; import org.apereo.cas.ticket.ServiceTicket; import org.apereo.cas.ticket.Ticket; import org.apereo.cas.ticket.TicketGrantingTicket; import org.apereo.cas.ticket.TicketGrantingTicketImpl; import org.apereo.cas.ticket.expiration.HardTimeoutExpirationPolicy; import org.apereo.cas.ticket.expiration.NeverExpiresExpirationPolicy; +import org.apereo.cas.ticket.proxy.ProxyGrantingTicket; +import org.apereo.cas.ticket.proxy.ProxyTicket; +import org.apereo.cas.ticket.tracking.TicketTrackingPolicy; +import org.apereo.cas.util.ProxyGrantingTicketIdGenerator; +import org.apereo.cas.util.ProxyTicketIdGenerator; import org.apereo.cas.util.ServiceTicketIdGenerator; import org.apereo.cas.util.TicketGrantingTicketIdGenerator; import org.apereo.cas.util.junit.EnabledIfListeningOnPort; @@ -29,6 +38,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.test.context.TestPropertySource; import java.util.ArrayList; @@ -345,7 +355,7 @@ private void addTicketAndWait(final String principalId) { "cas.ticket.registry.redis.host=localhost", "cas.ticket.registry.redis.port=6379" }) - class ConcurrentTests { + class ConcurrentAddTicketGrantingTicketTests { @Autowired @Qualifier(TicketRegistry.BEAN_NAME) private TicketRegistry ticketRegistry; @@ -356,7 +366,7 @@ void verifyConcurrentAddTicket() throws Throwable { val testHasFailed = new AtomicBoolean(); val threads = new ArrayList(); for (var i = 1; i <= 3; i++) { - val runnable = new RunnableAddTicket(ticketRegistry, principalId, 100); + val runnable = new RunnableAddTicketGrantingTicket(ticketRegistry, principalId, 100); val thread = new Thread(runnable); thread.setName("Thread-" + i); thread.setUncaughtExceptionHandler((t, e) -> { @@ -379,7 +389,7 @@ void verifyConcurrentAddTicket() throws Throwable { } @RequiredArgsConstructor - private static final class RunnableAddTicket implements Runnable { + private static final class RunnableAddTicketGrantingTicket implements Runnable { private final TicketRegistry ticketRegistry; private final String principalId; private final int max; @@ -397,4 +407,83 @@ public void run() { } } } + + @Nested + @SpringBootTest( + classes = { + RedisCoreConfiguration.class, + RedisTicketRegistryConfiguration.class, + BaseTicketRegistryTests.SharedTestConfiguration.class + }, properties = { + "cas.ticket.tgt.core.only-track-most-recent-session=false", + "cas.ticket.registry.redis.host=localhost", + "cas.ticket.registry.redis.port=6379" + }) + class ConcurrentAddProxyTicketTests { + @Autowired + @Qualifier(TicketRegistry.BEAN_NAME) + private TicketRegistry ticketRegistry; + + @Autowired + @Qualifier(org.apereo.cas.ticket.tracking.TicketTrackingPolicy.BEAN_NAME_SERVICE_TICKET_TRACKING) + private TicketTrackingPolicy serviceTicketSessionTrackingPolicy; + + @Test + void verifyConcurrentAddTicket() throws Throwable { + val principalId = UUID.randomUUID().toString(); + val authentication = CoreAuthenticationTestUtils.getAuthentication(principalId); + val tgtGenerator = new ProxyGrantingTicketIdGenerator(10, StringUtils.EMPTY); + val pgt = new ProxyGrantingTicketImpl(tgtGenerator.getNewTicketId(TicketGrantingTicket.PREFIX), + authentication, NeverExpiresExpirationPolicy.INSTANCE); + ticketRegistry.addTicket(pgt); + + val request = new MockHttpServletRequest(); + request.setParameter(CasProtocolConstants.PARAMETER_SERVICE, "http://foo.com"); + val service = new WebApplicationServiceFactory().createService(request); + + val testHasFailed = new AtomicBoolean(); + val threads = new ArrayList(); + for (var i = 1; i <= 3; i++) { + val runnable = new RunnableAddProxyTicket(ticketRegistry, pgt, service, serviceTicketSessionTrackingPolicy, 100); + val thread = new Thread(runnable); + thread.setName("Thread-" + i); + thread.setUncaughtExceptionHandler((t, e) -> { + LOGGER.error(e.getMessage(), e); + testHasFailed.set(true); + }); + threads.add(thread); + thread.start(); + } + for (val thread : threads) { + try { + thread.join(); + } catch (final Throwable e) { + fail(e); + } + } + if (testHasFailed.get()) { + fail("Test failed"); + } + } + + @RequiredArgsConstructor + private static final class RunnableAddProxyTicket implements Runnable { + private final TicketRegistry ticketRegistry; + private final ProxyGrantingTicket proxyGrantingTicket; + private final Service service; + private final TicketTrackingPolicy serviceTicketSessionTrackingPolicy; + private final int max; + + @Override + @SneakyThrows + public void run() { + val ptGenerator = new ProxyTicketIdGenerator(10, StringUtils.EMPTY); + for (int i = 0; i < max; i++) { + val proxyTicket = proxyGrantingTicket.grantProxyTicket(ptGenerator.getNewTicketId(ProxyTicket.PREFIX), + service, new HardTimeoutExpirationPolicy(20), serviceTicketSessionTrackingPolicy); + ticketRegistry.addTicket(proxyTicket); + } + } + } + } }