Skip to content

Commit

Permalink
Migrating Filter Tests from EasyMock to Mockito
Browse files Browse the repository at this point in the history
  • Loading branch information
bdemers committed Nov 1, 2022
1 parent d9a75a4 commit 37f4791
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,31 @@
import org.apache.shiro.web.mgt.WebSecurityManager;
import org.junit.Test;

import static org.easymock.EasyMock.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.when;

public class GuiceShiroFilterTest {

@Test
public void ensureInjectable() {
try {
InjectionPoint ip = InjectionPoint.forConstructorOf(GuiceShiroFilter.class);
InjectionPoint.forConstructorOf(GuiceShiroFilter.class);
} catch (Exception e) {
fail("Could not create constructor injection point.");
}
}

@Test
public void testConstructor() {
WebSecurityManager securityManager = createMock(WebSecurityManager.class);
FilterChainResolver filterChainResolver = createMock(FilterChainResolver.class);
ShiroFilterConfiguration filterConfiguration = createMock(ShiroFilterConfiguration.class);
expect(filterConfiguration.isStaticSecurityManagerEnabled()).andReturn(true);
expect(filterConfiguration.isFilterOncePerRequest()).andReturn(false);

replay(securityManager, filterChainResolver, filterConfiguration);
WebSecurityManager securityManager = mock(WebSecurityManager.class);
FilterChainResolver filterChainResolver = mock(FilterChainResolver.class);
ShiroFilterConfiguration filterConfiguration = mock(ShiroFilterConfiguration.class);
when(filterConfiguration.isStaticSecurityManagerEnabled()).thenReturn(true);
when(filterConfiguration.isFilterOncePerRequest()).thenReturn(false);

GuiceShiroFilter underTest = new GuiceShiroFilter(securityManager, filterChainResolver, filterConfiguration);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
package org.apache.shiro.web.servlet

import org.apache.shiro.web.config.ShiroFilterConfiguration

import javax.servlet.FilterConfig
import javax.servlet.ServletContext
import org.apache.shiro.web.env.EnvironmentLoader
import org.apache.shiro.web.env.WebEnvironment
import org.apache.shiro.web.filter.mgt.FilterChainResolver
import org.apache.shiro.web.mgt.WebSecurityManager
import org.junit.Test

import static org.easymock.EasyMock.*
import static org.junit.Assert.*
import javax.servlet.FilterConfig
import javax.servlet.ServletContext

import static org.hamcrest.MatcherAssert.assertThat
import static org.hamcrest.Matchers.sameInstance
import static org.mockito.ArgumentMatchers.eq
import static org.mockito.Mockito.mock
import static org.mockito.Mockito.when

/**
* Unit tests for {@link ShiroFilter}.
Expand All @@ -39,91 +42,81 @@ class ShiroFilterTest {
@Test
void testInit() {

def filterConfig = createStrictMock(FilterConfig)
def servletContext = createStrictMock(ServletContext)
def shiroFilterConfig = createStrictMock(ShiroFilterConfiguration)
def webEnvironment = createStrictMock(WebEnvironment)
def webSecurityManager = createStrictMock(WebSecurityManager)
def filterChainResolver = createStrictMock(FilterChainResolver)

expect(filterConfig.servletContext).andReturn(servletContext).anyTimes()
expect(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).andReturn null
expect(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).andReturn webEnvironment
expect(shiroFilterConfig.filterOncePerRequest).andReturn true
expect(shiroFilterConfig.staticSecurityManagerEnabled).andReturn false
expect(webEnvironment.shiroFilterConfiguration).andReturn shiroFilterConfig
expect(webEnvironment.webSecurityManager).andReturn webSecurityManager
expect(webEnvironment.filterChainResolver).andReturn filterChainResolver

replay filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
def filterConfig = mock(FilterConfig)
def servletContext = mock(ServletContext)
def shiroFilterConfig = mock(ShiroFilterConfiguration)
def webEnvironment = mock(WebEnvironment)
def webSecurityManager = mock(WebSecurityManager)
def filterChainResolver = mock(FilterChainResolver)

when(filterConfig.servletContext).thenReturn(servletContext)
when(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).thenReturn null
when(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).thenReturn webEnvironment
when(shiroFilterConfig.filterOncePerRequest).thenReturn true
when(shiroFilterConfig.staticSecurityManagerEnabled).thenReturn false
when(webEnvironment.shiroFilterConfiguration).thenReturn shiroFilterConfig
when(webEnvironment.webSecurityManager).thenReturn webSecurityManager
when(webEnvironment.filterChainResolver).thenReturn filterChainResolver

ShiroFilter filter = new ShiroFilter()

filter.init(filterConfig)

assertSame filter.securityManager, webSecurityManager
assertSame filter.filterChainResolver, filterChainResolver
assertTrue(filter.isFilterOncePerRequest())
assertFalse(filter.isStaticSecurityManagerEnabled())

verify filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
assertThat filter.securityManager, sameInstance(webSecurityManager)
assertThat filter.filterChainResolver, sameInstance(filterChainResolver)
assertThat("expected filter.isFilterOncePerRequest() to return true", filter.isFilterOncePerRequest())
assertThat("expected filter.isStaticSecurityManagerEnabled() to return false", !filter.isStaticSecurityManagerEnabled())
}

@Test
void configStaticSecManager_initParm() {

def filterConfig = createStrictMock(FilterConfig)
def servletContext = createStrictMock(ServletContext)
def shiroFilterConfig = createStrictMock(ShiroFilterConfiguration)
def webEnvironment = createStrictMock(WebEnvironment)
def webSecurityManager = createStrictMock(WebSecurityManager)
def filterChainResolver = createStrictMock(FilterChainResolver)

expect(filterConfig.servletContext).andReturn(servletContext).anyTimes()
expect(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).andReturn "true"
expect(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).andReturn webEnvironment
expect(shiroFilterConfig.filterOncePerRequest).andReturn false
expect(shiroFilterConfig.staticSecurityManagerEnabled).andReturn false
expect(webEnvironment.shiroFilterConfiguration).andReturn shiroFilterConfig
expect(webEnvironment.webSecurityManager).andReturn webSecurityManager
expect(webEnvironment.filterChainResolver).andReturn filterChainResolver

replay filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
def filterConfig = mock(FilterConfig)
def servletContext = mock(ServletContext)
def shiroFilterConfig = mock(ShiroFilterConfiguration)
def webEnvironment = mock(WebEnvironment)
def webSecurityManager = mock(WebSecurityManager)
def filterChainResolver = mock(FilterChainResolver)

when(filterConfig.servletContext).thenReturn(servletContext)
when(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).thenReturn "true"
when(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).thenReturn webEnvironment
when(shiroFilterConfig.filterOncePerRequest).thenReturn false
when(shiroFilterConfig.staticSecurityManagerEnabled).thenReturn false
when(webEnvironment.shiroFilterConfiguration).thenReturn shiroFilterConfig
when(webEnvironment.webSecurityManager).thenReturn webSecurityManager
when(webEnvironment.filterChainResolver).thenReturn filterChainResolver

ShiroFilter filter = new ShiroFilter()

filter.init(filterConfig)

assertTrue(filter.isStaticSecurityManagerEnabled())
verify filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
assertThat("expected filter.isStaticSecurityManagerEnabled() to return true", filter.isStaticSecurityManagerEnabled())
}

@Test
void configStaticSecManager_config() {

def filterConfig = createStrictMock(FilterConfig)
def servletContext = createStrictMock(ServletContext)
def shiroFilterConfig = createStrictMock(ShiroFilterConfiguration)
def webEnvironment = createStrictMock(WebEnvironment)
def webSecurityManager = createStrictMock(WebSecurityManager)
def filterChainResolver = createStrictMock(FilterChainResolver)

expect(filterConfig.servletContext).andReturn(servletContext).anyTimes()
expect(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).andReturn null
expect(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).andReturn webEnvironment
expect(shiroFilterConfig.filterOncePerRequest).andReturn false
expect(shiroFilterConfig.staticSecurityManagerEnabled).andReturn true
expect(webEnvironment.shiroFilterConfiguration).andReturn shiroFilterConfig
expect(webEnvironment.webSecurityManager).andReturn webSecurityManager
expect(webEnvironment.filterChainResolver).andReturn filterChainResolver

replay filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
def filterConfig = mock(FilterConfig)
def servletContext = mock(ServletContext)
def shiroFilterConfig = mock(ShiroFilterConfiguration)
def webEnvironment = mock(WebEnvironment)
def webSecurityManager = mock(WebSecurityManager)
def filterChainResolver = mock(FilterChainResolver)

when(filterConfig.servletContext).thenReturn(servletContext)
when(filterConfig.getInitParameter(eq(AbstractShiroFilter.STATIC_INIT_PARAM_NAME))).thenReturn null
when(servletContext.getAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY))).thenReturn webEnvironment
when(shiroFilterConfig.filterOncePerRequest).thenReturn false
when(shiroFilterConfig.staticSecurityManagerEnabled).thenReturn true
when(webEnvironment.shiroFilterConfiguration).thenReturn shiroFilterConfig
when(webEnvironment.webSecurityManager).thenReturn webSecurityManager
when(webEnvironment.filterChainResolver).thenReturn filterChainResolver

ShiroFilter filter = new ShiroFilter()

filter.init(filterConfig)

assertTrue(filter.isStaticSecurityManagerEnabled())
verify filterConfig, servletContext, webEnvironment, webSecurityManager, filterChainResolver, shiroFilterConfig
assertThat("expected filter.isStaticSecurityManagerEnabled() to return true", filter.isStaticSecurityManagerEnabled())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
import javax.servlet.ServletResponse;
import java.io.IOException;

import static org.easymock.EasyMock.*;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Unit tests for the {@link OncePerRequestFilter} implementation.
Expand All @@ -48,9 +49,9 @@ public class OncePerRequestFilterTest {
@Before
public void setUp() {
filter = createTestInstance();
chain = createNiceMock(FilterChain.class);
request = createNiceMock(ServletRequest.class);
response = createNiceMock(ServletResponse.class);
chain = mock(FilterChain.class);
request = mock(ServletRequest.class);
response = mock(ServletResponse.class);
}

private CountingOncePerRequestFilter createTestInstance() {
Expand All @@ -63,12 +64,10 @@ private CountingOncePerRequestFilter createTestInstance() {
@SuppressWarnings({"JavaDoc"})
@Test
public void testEnabled() throws IOException, ServletException {
expect(request.getAttribute(ATTR_NAME)).andReturn(null).anyTimes();
replay(request);
when(request.getAttribute(ATTR_NAME)).thenReturn(null);

filter.doFilter(request, response, chain);

verify(request);
assertEquals("Filter should have executed", 1, filter.filterCount);
}

Expand All @@ -80,26 +79,22 @@ public void testEnabled() throws IOException, ServletException {
public void testDisabled() throws IOException, ServletException {
filter.setEnabled(false); //test disabled

expect(request.getAttribute(ATTR_NAME)).andReturn(null).anyTimes();
replay(request);
when(request.getAttribute(ATTR_NAME)).thenReturn(null);

filter.doFilter(request, response, chain);

verify(request);
assertEquals("Filter should NOT have executed", 0, filter.filterCount);
}

@Test
public void testFilterOncePerRequest() throws IOException, ServletException {
filter.setFilterOncePerRequest(false);

expect(request.getAttribute(ATTR_NAME)).andReturn(null).andReturn(true);
replay(request);
when(request.getAttribute(ATTR_NAME)).thenReturn(null, true);

filter.doFilter(request, response, chain);
filter.doFilter(request, response, chain);

verify(request);
assertEquals("Filter should have executed twice", 2, filter.filterCount);
}

Expand All @@ -116,5 +111,4 @@ protected void doFilterInternal(ServletRequest request, ServletResponse response
filterCount++;
}
}

}

0 comments on commit 37f4791

Please sign in to comment.