From 17f3a8bff79a29ae74a3bc12317519fb50863e86 Mon Sep 17 00:00:00 2001 From: Viacheslav Gagara Date: Thu, 18 May 2017 18:02:01 +0300 Subject: [PATCH] lock RefreshToken entity with pissimistic locking to avoid concurrent updates (unit test added) --- .../oauth2/provider/JPAOAuthDataProvider.java | 34 ++++++++++++ .../code/JPACMTOAuthDataProviderTest.java | 55 ++++++++++++++++++- .../provider/JPAOAuthDataProviderTest.java | 2 +- 3 files changed, 89 insertions(+), 2 deletions(-) diff --git a/rt/rs/security/oauth-parent/oauth2/src/main/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProvider.java b/rt/rs/security/oauth-parent/oauth2/src/main/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProvider.java index f59b40e858f..d93fd4f6181 100644 --- a/rt/rs/security/oauth-parent/oauth2/src/main/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProvider.java +++ b/rt/rs/security/oauth-parent/oauth2/src/main/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProvider.java @@ -19,12 +19,16 @@ package org.apache.cxf.rs.security.oauth2.provider; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; +import java.util.Map; + import javax.persistence.EntityManager; import javax.persistence.EntityManagerFactory; import javax.persistence.EntityTransaction; +import javax.persistence.LockModeType; import javax.persistence.TypedQuery; import org.apache.cxf.helpers.CastUtils; @@ -52,6 +56,10 @@ public class JPAOAuthDataProvider extends AbstractOAuthDataProvider { private static final String CLIENT_QUERY = "SELECT client FROM Client client" + " INNER JOIN client.resourceOwnerSubject ros"; + private static final int DEFAULT_PESSIMISTIC_LOCK_TIMEOUT = 10000; + + private int pessimisticLockTimeout = DEFAULT_PESSIMISTIC_LOCK_TIMEOUT; + private EntityManagerFactory entityManagerFactory; public void setEntityManagerFactory(EntityManagerFactory emf) { @@ -213,6 +221,25 @@ public RefreshToken execute(EntityManager em) { }); } + protected void lockRefreshTokenForUpdate(final RefreshToken refreshToken) { + try { + execute(new EntityManagerOperation() { + + @Override + public Void execute(EntityManager em) { + Map options = Collections.emptyMap(); + if (pessimisticLockTimeout > 0) { + Collections.singletonMap("javax.persistence.lock.timeout", pessimisticLockTimeout); + } + em.refresh(refreshToken, LockModeType.PESSIMISTIC_WRITE, options); + return null; + } + }); + } catch (IllegalArgumentException e) { + // entity is not managed yet. ignore + } + } + @Override protected void doRevokeRefreshToken(final RefreshToken rt) { executeInTransaction(new EntityManagerOperation() { @@ -288,6 +315,13 @@ protected void saveRefreshToken(RefreshToken refreshToken) { persistEntity(refreshToken); } + @Override + protected RefreshToken updateRefreshToken(RefreshToken rt, ServerAccessToken at) { + // lock RT for update + lockRefreshTokenForUpdate(rt); + return super.updateRefreshToken(rt, at); + } + protected void persistEntity(final Object entity) { executeInTransaction(new EntityManagerOperation() { @Override diff --git a/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/grants/code/JPACMTOAuthDataProviderTest.java b/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/grants/code/JPACMTOAuthDataProviderTest.java index a7245e2695f..73c7d807770 100644 --- a/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/grants/code/JPACMTOAuthDataProviderTest.java +++ b/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/grants/code/JPACMTOAuthDataProviderTest.java @@ -18,13 +18,23 @@ */ package org.apache.cxf.rs.security.oauth2.grants.code; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.cxf.rs.security.oauth2.common.AccessTokenRegistration; +import org.apache.cxf.rs.security.oauth2.common.Client; +import org.apache.cxf.rs.security.oauth2.common.ServerAccessToken; import org.apache.cxf.rs.security.oauth2.provider.JPAOAuthDataProvider; import org.apache.cxf.rs.security.oauth2.provider.JPAOAuthDataProviderTest; +import org.apache.cxf.rs.security.oauth2.tokens.refresh.RefreshToken; import org.junit.After; import org.junit.Before; +import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.annotation.DirtiesContext.ClassMode; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -42,7 +52,7 @@ */ @RunWith(SpringJUnit4ClassRunner.class) @ContextConfiguration("JPACMTCodeDataProvider.xml") -@DirtiesContext +@DirtiesContext(classMode = ClassMode.AFTER_EACH_TEST_METHOD) @ActiveProfiles("hibernate") public class JPACMTOAuthDataProviderTest extends JPAOAuthDataProviderTest { @@ -64,4 +74,47 @@ public void setUp() { @Override public void tearDown() { } + + @Test + public void testRefreshAccessTokenConcurrently() throws Exception { + getProvider().setRecycleRefreshTokens(false); + + Client c = addClient("101", "bob"); + + AccessTokenRegistration atr = new AccessTokenRegistration(); + atr.setClient(c); + atr.setApprovedScope(Arrays.asList("a", "refreshToken")); + atr.setSubject(null); + final ServerAccessToken at = getProvider().createAccessToken(atr); + + Runnable task = new Runnable() { + + @Override + public void run() { + getProvider().refreshAccessToken(c, at.getRefreshToken(), Collections.emptyList()); + } + }; + + Thread th1 = new Thread(task); + Thread th2 = new Thread(task); + Thread th3 = new Thread(task); + + th1.start(); + th2.start(); + th3.start(); + + th1.join(); + th2.join(); + th3.join(); + + assertNotNull(getProvider().getAccessToken(at.getTokenKey())); + List rtl = getProvider().getRefreshTokens(c, null); + assertNotNull(rtl); + assertEquals(1, rtl.size()); + List atl = rtl.get(0).getAccessTokens(); + assertNotNull(atl); + + // after 3 parallel refreshes we should have 4 AccessTokens + assertEquals(4, atl.size()); + } } diff --git a/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProviderTest.java b/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProviderTest.java index cc0cf2959b7..7094a28705b 100644 --- a/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProviderTest.java +++ b/rt/rs/security/oauth-parent/oauth2/src/test/java/org/apache/cxf/rs/security/oauth2/provider/JPAOAuthDataProviderTest.java @@ -263,7 +263,7 @@ public void testAddGetDeleteRefreshToken() { assertNull(getProvider().getRefreshToken(rt.getTokenKey())); } - private Client addClient(String clientId, String userLogin) { + protected Client addClient(String clientId, String userLogin) { Client c = new Client(); c.setRedirectUris(Collections.singletonList("http://client/redirect")); c.setClientId(clientId);