-
Notifications
You must be signed in to change notification settings - Fork 188
/
MidpointAuthFilter.java
344 lines (289 loc) · 16.8 KB
/
MidpointAuthFilter.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
/*
* Copyright (c) 2010-2019 Evolveum and contributors
*
* This work is dual-licensed under the Apache License 2.0
* and European Union Public License. See LICENSE file for details.
*/
package com.evolveum.midpoint.web.security.filter;
import com.evolveum.midpoint.model.api.authentication.*;
import com.evolveum.midpoint.model.common.SystemObjectCache;
import com.evolveum.midpoint.prism.PrismContext;
import com.evolveum.midpoint.prism.PrismObject;
import com.evolveum.midpoint.schema.result.OperationResult;
import com.evolveum.midpoint.schema.util.SecurityPolicyUtil;
import com.evolveum.midpoint.util.exception.SchemaException;
import com.evolveum.midpoint.util.logging.Trace;
import com.evolveum.midpoint.util.logging.TraceManager;
import com.evolveum.midpoint.web.security.MidpointAuthenticationManager;
import com.evolveum.midpoint.web.security.factory.channel.AuthChannelRegistryImpl;
import com.evolveum.midpoint.web.security.module.ModuleWebSecurityConfig;
import com.evolveum.midpoint.web.security.factory.module.AuthModuleRegistryImpl;
import com.evolveum.midpoint.web.security.util.SecurityUtils;
import com.evolveum.midpoint.xml.ns._public.common.common_3.*;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.config.annotation.ObjectPostProcessor;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.web.filter.GenericFilterBean;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.*;
/**
 * Servlet filter driving midPoint's flexible (sequence/module based) authentication.
 * <p>
 * For each request that is not a plain permit-all page it:
 * <ol>
 *     <li>loads the authentication part of the security policy (falling back to the default policy),</li>
 *     <li>resolves the authentication sequence for the request,</li>
 *     <li>builds the authentication modules for that sequence, or reuses the ones stored in the current
 *     {@link MidpointAuthentication} when the sequence has not changed,</li>
 *     <li>runs the request through the security filter chain of the appropriate module(s) and then
 *     continues with the original servlet filter chain.</li>
 * </ol>
 *
 * @author skublik
 */
public class MidpointAuthFilter extends GenericFilterBean {

    private static final Trace LOGGER = TraceManager.getTrace(MidpointAuthFilter.class);

    /** Value of the processing-module index that signals that the authentication flow must be restarted. */
    private static final int RESTART_FLOW_INDEX = -1;

    private final Map<Class<?>, Object> sharedObjects;

    @Autowired
    private ObjectPostProcessor<Object> objectObjectPostProcessor;

    @Autowired
    private SystemObjectCache systemObjectCache;

    @Autowired
    private AuthModuleRegistryImpl authModuleRegistry;

    @Autowired
    private AuthChannelRegistryImpl authChannelRegistry;

    @Autowired
    private MidpointAuthenticationManager authenticationManager;

    @Autowired
    private PrismContext prismContext;

    // Lazily created default policy. Volatile because one filter instance serves concurrent requests,
    // so the object must be safely published to other threads.
    private volatile AuthenticationsPolicyType authenticationPolicy;

    private final PreLogoutFilter preLogoutFilter = new PreLogoutFilter();

    /**
     * @param sharedObjects objects passed on to {@code SecurityUtils.buildModuleFilters} when module filters are built
     */
    public MidpointAuthFilter(Map<Class<? extends Object>, Object> sharedObjects) {
        this.sharedObjects = sharedObjects;
    }

    /**
     * @return the filter that rewrites the generic logout path to the logout path of the actual module
     */
    public PreLogoutFilter getPreLogoutFilter() {
        return preLogoutFilter;
    }

    /**
     * Creates a {@link ModuleWebSecurityConfig} without module configuration, runs it through the object post
     * processor and sets the post processor on it.
     * <p>
     * NOTE(review): the created configuration is not stored or returned; confirm whether the call is only needed
     * for the side effects of post-processing.
     */
    public void createFilterForAuthenticatedRequest() {
        ModuleWebSecurityConfig module = objectObjectPostProcessor.postProcess(new ModuleWebSecurityConfig(null));
        module.setObjectPostProcessor(objectObjectPostProcessor);
    }

    /**
     * Returns the default authentication policy, creating it on the first call.
     * <p>
     * Concurrent first calls may each create a policy; the last one written wins, as before.
     *
     * @return default authentication policy built from the prism schema registry
     * @throws SchemaException if the default policy cannot be created
     */
    public AuthenticationsPolicyType getDefaultAuthenticationPolicy() throws SchemaException {
        AuthenticationsPolicyType policy = authenticationPolicy;
        if (policy == null) {
            policy = SecurityPolicyUtil.createDefaultAuthenticationPolicy(prismContext.getSchemaRegistry());
            authenticationPolicy = policy;
        }
        return policy;
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        doFilterInternal(request, response, chain);
    }

    /**
     * Main filtering logic; see the class documentation for the overall flow.
     */
    private void doFilterInternal(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) request;

        // request for permit all page (for example errors and login pages)
        if (SecurityUtils.isPermitAll(httpRequest) && !SecurityUtils.isLoginPage(httpRequest)) {
            chain.doFilter(request, response);
            return;
        }

        MidpointAuthentication mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();

        AuthenticationsPolicyType authenticationsPolicy;
        CredentialsPolicyType credentialsPolicy = null;
        PrismObject<SecurityPolicyType> authPolicy = null;
        try {
            authPolicy = getSecurityPolicy();
            authenticationsPolicy = getAuthenticationPolicy(authPolicy);
            if (authPolicy != null) {
                credentialsPolicy = authPolicy.asObjectable().getCredentials();
            }
        } catch (SchemaException e) {
            LOGGER.error("Couldn't load Authentication policy", e);
            authenticationsPolicy = getDefaultAuthenticationPolicyOrFail(e);
        }

        // is path for which is ignored authentication
        if (SecurityUtils.isIgnoredLocalPath(authenticationsPolicy, httpRequest)) {
            chain.doFilter(request, response);
            return;
        }

        AuthenticationSequenceType sequence = getAuthenticationSequence(mpAuthentication, httpRequest, authenticationsPolicy);
        if (sequence == null) {
            // authPolicy is null when no security policy exists (default policy is used); avoid an NPE here
            String policyOid = authPolicy != null ? authPolicy.getOid() : null;
            throw new IllegalArgumentException("Couldn't find sequence for URI '" + httpRequest.getRequestURI()
                    + "' in authentication of Security Policy with oid " + policyOid);
        }

        // change generic logout path to logout path for actual module
        getPreLogoutFilter().doFilter(request, response);

        AuthenticationChannel authenticationChannel = SecurityUtils.buildAuthChannel(authChannelRegistry, sequence);

        List<AuthModule> authModules = createAuthenticationModuleBySequence(mpAuthentication, sequence, httpRequest,
                authenticationsPolicy.getModules(), authenticationChannel, credentialsPolicy);

        // authenticated request
        if (mpAuthentication != null && mpAuthentication.isAuthenticated() && sequence.equals(mpAuthentication.getSequence())) {
            processingOfAuthenticatedRequest(mpAuthentication, httpRequest, response, chain);
            return;
        }

        // couldn't find authentication modules
        if (authModules == null || authModules.isEmpty()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug(UrlUtils.buildRequestUrl(httpRequest) + " has no filters");
            }
            throw new AuthenticationServiceException("Couldn't find filters for sequence " + sequence.getName());
        }

        int indexOfProcessingModule = getIndexOfActualProcessingModule(mpAuthentication, httpRequest);
        resolveErrorWithMoreModules(mpAuthentication, httpRequest);

        if (needRestartAuthFlow(indexOfProcessingModule)) {
            indexOfProcessingModule = restartAuthFlow(httpRequest, sequence, authModules);
            mpAuthentication = (MidpointAuthentication) SecurityContextHolder.getContext().getAuthentication();
        }

        if (mpAuthentication.getAuthenticationChannel() == null) {
            mpAuthentication.setAuthenticationChannel(authenticationChannel);
        }

        VirtualFilterChain vfc = new VirtualFilterChain(chain,
                authModules.get(indexOfProcessingModule).getSecurityFilterChain().getFilters());
        vfc.doFilter(httpRequest, response);
    }

    /**
     * Falls back to the default authentication policy after the configured one could not be loaded.
     *
     * @param loadFailure error raised while loading the configured policy; attached as a suppressed exception
     *                    if the fallback fails as well
     * @return default authentication policy
     * @throws IllegalArgumentException if the default policy cannot be created either; its cause is that error
     */
    private AuthenticationsPolicyType getDefaultAuthenticationPolicyOrFail(SchemaException loadFailure) {
        try {
            return getDefaultAuthenticationPolicy();
        } catch (SchemaException defaultPolicyFailure) {
            LOGGER.error("Couldn't get default authentication policy", defaultPolicyFailure);
            IllegalArgumentException failure =
                    new IllegalArgumentException("Couldn't get default authentication policy", defaultPolicyFailure);
            failure.addSuppressed(loadFailure);
            throw failure;
        }
    }

    /**
     * @return {@code true} if the index signals that the authentication flow has to be restarted
     */
    private boolean needRestartAuthFlow(int indexOfProcessingModule) {
        return indexOfProcessingModule == RESTART_FLOW_INDEX;
    }

    /**
     * Replaces the authentication in the security context with a fresh {@link MidpointAuthentication} for the
     * given sequence. The first module's base authentication is registered on it.
     *
     * @param authModules modules of the sequence; must not be empty (checked by the caller)
     * @return index of the module to process, as resolved by {@code resolveParallelModules} starting from 0
     */
    private int restartAuthFlow(HttpServletRequest httpRequest, AuthenticationSequenceType sequence, List<AuthModule> authModules) {
        MidpointAuthentication mpAuthentication = new MidpointAuthentication(sequence);
        SecurityContextHolder.getContext().setAuthentication(mpAuthentication);
        mpAuthentication.setAuthModules(authModules);
        mpAuthentication.setSessionId(httpRequest.getSession().getId());
        mpAuthentication.addAuthentications(authModules.get(0).getBaseModuleAuthentication());
        return mpAuthentication.resolveParallelModules(httpRequest, 0);
    }

    /**
     * If the authentication flow failed and the sequence has more than one module, stores an exception in the
     * request whose message is the previous exception message (if any) followed by the restart-flow message key.
     */
    private void resolveErrorWithMoreModules(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest) {
        if (mpAuthentication == null || !mpAuthentication.isAuthenticationFailed()
                || mpAuthentication.getAuthModules().size() <= 1) {
            return;
        }
        Exception actualException = (Exception) httpRequest.getSession().getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
        String restartFlowMessage = "web.security.flexAuth.restart.flow";
        String actualMessage;
        if (actualException != null && StringUtils.isNotBlank(actualException.getMessage())) {
            actualMessage = actualException.getMessage() + ";" + restartFlowMessage;
        } else {
            actualMessage = restartFlowMessage;
        }
        AuthenticationException exception = new AuthenticationServiceException(actualMessage);
        SecurityUtils.saveException(httpRequest, exception);
    }

    /**
     * Finds the module currently being processed.
     * <p>
     * The security context is checked instead of {@code mpAuthentication} on purpose:
     * {@link #createAuthenticationModuleBySequence} clears the context when the sequence changed. In that case
     * the flow must restart even though the (stale) {@code mpAuthentication} is still non-null.
     *
     * @return index of the processing module, or {@link #RESTART_FLOW_INDEX} when no authentication is in progress
     */
    private int getIndexOfActualProcessingModule(MidpointAuthentication mpAuthentication, HttpServletRequest request) {
        int indexOfProcessingModule = RESTART_FLOW_INDEX;
        if (SecurityContextHolder.getContext().getAuthentication() != null) {
            indexOfProcessingModule = mpAuthentication.getIndexOfProcessingModule(true);
            indexOfProcessingModule = mpAuthentication.resolveParallelModules(request, indexOfProcessingModule);
        }
        return indexOfProcessingModule;
    }

    /**
     * Returns the authentication modules for the sequence.
     * <p>
     * When there is no authentication yet, or the sequence differs from the current one, this clears the security
     * context and the providers of the authentication manager, then builds new module filters. Otherwise it reuses
     * the modules stored in the current authentication.
     */
    private List<AuthModule> createAuthenticationModuleBySequence(MidpointAuthentication mpAuthentication, AuthenticationSequenceType sequence,
            HttpServletRequest httpRequest, AuthenticationModulesType modules, AuthenticationChannel authenticationChannel, CredentialsPolicyType credentialsPolicy) {
        if (mpAuthentication != null && sequence.equals(mpAuthentication.getSequence())) {
            return mpAuthentication.getAuthModules();
        }
        // change sequence of authentication during another sequence
        SecurityContextHolder.getContext().setAuthentication(null);
        authenticationManager.getProviders().clear();
        return SecurityUtils.buildModuleFilters(authModuleRegistry, sequence, httpRequest, modules,
                credentialsPolicy, sharedObjects, authenticationChannel);
    }

    /**
     * Resolves the authentication sequence for the request.
     * <p>
     * Login pages visited during an ongoing flow keep the current sequence. Other requests use the sequence mapped
     * to the request path. The current sequence is also kept when the user is authenticated and the new sequence
     * uses the same channel. If the request targets the base path of the new sequence, the logout path is
     * redirected to it and the authenticated module is flagged for internal logout.
     *
     * @return resolved sequence, possibly {@code null} when no sequence matches the path
     */
    private AuthenticationSequenceType getAuthenticationSequence(MidpointAuthentication mpAuthentication, HttpServletRequest httpRequest, AuthenticationsPolicyType authenticationsPolicy) {
        AuthenticationSequenceType sequence;
        // permitAll pages (login, select ID for saml ...) during processing of modules
        if (mpAuthentication != null && SecurityUtils.isLoginPage(httpRequest)) {
            sequence = mpAuthentication.getSequence();
        } else {
            sequence = SecurityUtils.getSequenceByPath(httpRequest, authenticationsPolicy);
        }

        // use same sequence if focus is authenticated and channel id of new sequence is same
        if (mpAuthentication != null && !Objects.equals(mpAuthentication.getSequence(), sequence)
                && mpAuthentication.isAuthenticated()
                && isSameChannel(mpAuthentication.getAuthenticationChannel(), sequence, httpRequest)) {
            // change logout path to new sequence
            if (SecurityUtils.isBasePathForSequence(httpRequest, sequence)) {
                mpAuthentication.getAuthenticationChannel().setPathAfterLogout(httpRequest.getServletPath());
                ModuleAuthentication authenticatedModule = SecurityUtils.getAuthenticatedModule();
                if (authenticatedModule != null) {
                    authenticatedModule.setInternalLogout(true);
                }
            }
            sequence = mpAuthentication.getSequence();
        }
        return sequence;
    }

    /**
     * Checks whether the current channel matches the channel of the new sequence or the channel derived from the
     * request.
     *
     * @param actualChannel channel of the current authentication; may be {@code null} (no match)
     */
    private boolean isSameChannel(AuthenticationChannel actualChannel, AuthenticationSequenceType sequence, HttpServletRequest httpRequest) {
        if (actualChannel == null) {
            return false;
        }
        if (sequence != null && sequence.getChannel() != null && actualChannel.matchChannel(sequence)) {
            return true;
        }
        return actualChannel.getChannelId().equals(SecurityUtils.findChannelByRequest(httpRequest));
    }

    /**
     * Returns the authentication part of the security policy. Falls back to the default policy when there is no
     * security policy, it has no authentication section, or that section defines no sequence.
     *
     * @throws SchemaException if the default policy has to be created and creation fails
     */
    private AuthenticationsPolicyType getAuthenticationPolicy(PrismObject<SecurityPolicyType> authPolicy) throws SchemaException {
        if (authPolicy == null) {
            return getDefaultAuthenticationPolicy();
        }
        AuthenticationsPolicyType authentication = authPolicy.asObjectable().getAuthentication();
        if (authentication == null || authentication.getSequence() == null || authentication.getSequence().isEmpty()) {
            return getDefaultAuthenticationPolicy();
        }
        return authentication;
    }

    /**
     * @return security policy from the system object cache; may be {@code null} (callers handle that case)
     */
    private PrismObject<SecurityPolicyType> getSecurityPolicy() throws SchemaException {
        return systemObjectCache.getSecurityPolicy(new OperationResult("load security policy"));
    }

    /**
     * Runs an already authenticated request through the security filter chain of every module whose state is
     * {@code SUCCESSFULLY}.
     * <p>
     * NOTE(review): each virtual chain ends by calling the original chain. If more than one module is in the
     * {@code SUCCESSFULLY} state, the downstream chain would be invoked once per module. Confirm this cannot happen.
     */
    private void processingOfAuthenticatedRequest(MidpointAuthentication mpAuthentication, ServletRequest httpRequest, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        for (ModuleAuthentication moduleAuthentication : mpAuthentication.getAuthentications()) {
            if (StateOfModule.SUCCESSFULLY.equals(moduleAuthentication.getState())) {
                int i = mpAuthentication.getIndexOfModule(moduleAuthentication);
                VirtualFilterChain vfc = new VirtualFilterChain(chain,
                        mpAuthentication.getAuthModules().get(i).getSecurityFilterChain().getFilters());
                vfc.doFilter(httpRequest, response);
            }
        }
    }

    /**
     * Filter chain that first runs the given additional filters in order and then delegates to the original chain.
     * Instances keep a cursor and are therefore single-use, one per request.
     */
    private static class VirtualFilterChain implements FilterChain {

        private final FilterChain originalChain;
        private final List<Filter> additionalFilters;
        private final int size;
        private int currentPosition = 0;

        private VirtualFilterChain(FilterChain chain, List<Filter> additionalFilters) {
            this.originalChain = chain;
            this.additionalFilters = additionalFilters;
            this.size = additionalFilters.size();
        }

        @Override
        public void doFilter(ServletRequest request, ServletResponse response)
                throws IOException, ServletException {
            if (currentPosition == size) {
                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug(UrlUtils.buildRequestUrl((HttpServletRequest) request)
                            + " reached end of additional filter chain; proceeding with original chain, if url is permit all");
                }
                originalChain.doFilter(request, response);
                return;
            }

            currentPosition++;
            Filter nextFilter = additionalFilters.get(currentPosition - 1);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug(UrlUtils.buildRequestUrl((HttpServletRequest) request)
                        + " at position " + currentPosition + " of " + size
                        + " in additional filter chain; firing Filter: '"
                        + nextFilter.getClass().getSimpleName() + "'");
            }
            nextFilter.doFilter(request, response, this);
        }
    }

    /** Validator hook for a configured {@link MidpointAuthFilter}. */
    public interface FilterChainValidator {
        void validate(MidpointAuthFilter filterChainProxy);
    }

    /** Validator that performs no checks. */
    private static class NullFilterChainValidator implements MidpointAuthFilter.FilterChainValidator {
        @Override
        public void validate(MidpointAuthFilter filterChainProxy) {
        }
    }
}