@@ -320,37 +320,54 @@ def post_entity_categories(maps, **kwargs):
320320class Policy (object ):
321321 """ handles restrictions on assertions """
322322
323- def __init__ (self , restrictions = None ):
324- if restrictions :
325- self .compile (restrictions )
326- else :
327- self ._restrictions = None
323+ def __init__ (self , restrictions = None , config = None ):
324+ self ._config = config
325+ self ._restrictions = self .setup_restrictions (restrictions )
326+ logger .debug ("policy restrictions: %s" , self ._restrictions )
328327 self .acs = []
329328
330- def compile (self , restrictions ):
329+ def setup_restrictions (self , restrictions = None ):
330+ if restrictions is None :
331+ return None
332+
333+ restrictions = copy .deepcopy (restrictions )
334+ # TODO: Split policy config in service_providers and registration_authorities
335+ # "policy": {
336+ # "service_providers": {
337+ # "default": ...,
338+ # "urn:mace:example.com:saml:roland:sp": ...,
339+ # },
340+ # "registration_authorities": {
341+ # "default": ...,
342+ # "http://www.swamid.se": ...,
343+ # },
344+ # },
345+ registration_authorities = restrictions .pop ('registration_authorities' , None )
346+ restrictions = self .compile (restrictions )
347+ if registration_authorities :
348+ restrictions ['registration_authorities' ] = self .compile (registration_authorities )
349+ return restrictions
350+
351+ @staticmethod
352+ def compile (restrictions ):
331353 """ This is only for IdPs or AAs, and it's about limiting what
332354 is returned to the SP.
333355 In the configuration file, restrictions on which values that
334356 can be returned are specified with the help of regular expressions.
335357 This function goes through and pre-compiles the regular expressions.
336358
337- :param restrictions:
359+ :param restrictions: policy configuration
338360 :return: The assertion with the string specification replaced with
339361 a compiled regular expression.
340362 """
341-
342- self ._restrictions = copy .deepcopy (restrictions )
343-
344- for who , spec in self ._restrictions .items ():
363+ for who , spec in restrictions .items ():
345364 if spec is None :
346365 continue
347- try :
348- items = spec ["entity_categories" ]
349- except KeyError :
350- pass
351- else :
366+
367+ entity_categories = spec .get ("entity_categories" )
368+ if entity_categories is not None :
352369 ecs = []
353- for cat in items :
370+ for cat in entity_categories :
354371 try :
355372 _mod = importlib .import_module (cat )
356373 except ImportError :
@@ -366,25 +383,27 @@ def compile(self, restrictions):
366383 _ec [key ] = (alist , _only_required )
367384 ecs .append (_ec )
368385 spec ["entity_categories" ] = ecs
369- try :
370- restr = spec ["attribute_restrictions" ]
371- except KeyError :
372- continue
373386
374- if restr is None :
387+ attribute_restrictions = spec .get ("attribute_restrictions" )
388+ if attribute_restrictions is None :
375389 continue
376390
377- _are = {}
378- for key , values in restr .items ():
391+ _attribute_restrictions = {}
392+ for key , values in attribute_restrictions .items ():
379393 if not values :
380- _are [key .lower ()] = None
394+ _attribute_restrictions [key .lower ()] = None
381395 continue
396+ _attribute_restrictions [key .lower ()] = [re .compile (value ) for value in values ]
382397
383- _are [key .lower ()] = [re .compile (value ) for value in values ]
384- spec ["attribute_restrictions" ] = _are
385- logger .debug ("policy restrictions: %s" , self ._restrictions )
398+ spec ["attribute_restrictions" ] = _attribute_restrictions
386399
387- return self ._restrictions
400+ return restrictions
401+
402+ def _lookup_registry_authority (self , sp_entity_id ):
403+ if self ._config and self ._config .metadata :
404+ registration_info = self ._config .metadata .registration_info (sp_entity_id )
405+ return registration_info .get ('registration_authority' )
406+ return None
388407
389408 def get (self , attribute , sp_entity_id , default = None , post_func = None ,
390409 ** kwargs ):
@@ -399,16 +418,22 @@ def get(self, attribute, sp_entity_id, default=None, post_func=None,
399418 if not self ._restrictions :
400419 return default
401420
402- try :
403- try :
404- val = self ._restrictions [sp_entity_id ][attribute ]
405- except KeyError :
406- try :
407- val = self ._restrictions ["default" ][attribute ]
408- except KeyError :
409- val = None
410- except KeyError :
411- val = None
421+ registration_authority_name = self ._lookup_registry_authority (sp_entity_id )
422+ registration_authorities = self ._restrictions .get ("registration_authorities" )
423+
424+ val = None
425+ # Specific SP takes precedence
426+ if sp_entity_id in self ._restrictions :
427+ val = self ._restrictions [sp_entity_id ].get (attribute )
428+ # Second choice is if the SP is part of a configured registration authority
429+ elif registration_authorities and registration_authority_name in registration_authorities :
430+ val = registration_authorities [registration_authority_name ].get (attribute )
431+ # Third is to try default for registration authorities
432+ elif registration_authorities and 'default' in registration_authorities :
433+ val = registration_authorities ['default' ].get (attribute )
434+ # Lastly we try default for SPs
435+ elif 'default' in self ._restrictions :
436+ val = self ._restrictions .get ('default' ).get (attribute )
412437
413438 if val is None :
414439 return default
@@ -422,16 +447,15 @@ def get_nameid_format(self, sp_entity_id):
422447 :param: The SP entity ID
423448 :retur: The format
424449 """
425- return self .get ("nameid_format" , sp_entity_id ,
426- saml .NAMEID_FORMAT_TRANSIENT )
450+ return self .get ("nameid_format" , sp_entity_id , saml .NAMEID_FORMAT_TRANSIENT )
427451
428452 def get_name_form (self , sp_entity_id ):
429453 """ Get the NameFormat to used for the entity id
430454 :param: The SP entity ID
431455 :retur: The format
432456 """
433457
434- return self .get ("name_form" , sp_entity_id , NAME_FORMAT_URI )
458+ return self .get ("name_form" , sp_entity_id , default = NAME_FORMAT_URI )
435459
436460 def get_lifetime (self , sp_entity_id ):
437461 """ The lifetime of the assertion
@@ -458,32 +482,20 @@ def get_fail_on_missing_requested(self, sp_entity_id):
458482 :return: The restrictions
459483 """
460484
461- return self .get ("fail_on_missing_requested" , sp_entity_id , True )
462-
463- def entity_category_attributes (self , ec ):
464- if not self ._restrictions :
465- return None
466-
467- ec_maps = self ._restrictions ["default" ]["entity_categories" ]
468- for ec_map in ec_maps :
469- try :
470- return ec_map [ec ]
471- except KeyError :
472- pass
473- return []
485+ return self .get ("fail_on_missing_requested" , sp_entity_id , default = True )
474486
475487 def get_entity_categories (self , sp_entity_id , mds , required ):
476488 """
477489
478490 :param sp_entity_id:
479491 :param mds: MetadataStore instance
492+ :param required: required attributes
480493 :return: A dictionary with restrictions
481494 """
482495
483496 kwargs = {"mds" : mds , 'required' : required }
484497
485- return self .get ("entity_categories" , sp_entity_id , default = {},
486- post_func = post_entity_categories , ** kwargs )
498+ return self .get ("entity_categories" , sp_entity_id , default = {}, post_func = post_entity_categories , ** kwargs )
487499
488500 def not_on_or_after (self , sp_entity_id ):
489501 """ When the assertion stops being valid, should not be
@@ -495,6 +507,17 @@ def not_on_or_after(self, sp_entity_id):
495507
496508 return in_a_while (** self .get_lifetime (sp_entity_id ))
497509
510+ def get_sign (self , sp_entity_id ):
511+ """
512+ Possible choices
513+ "sign": ["response", "assertion", "on_demand"]
514+
515+ :param sp_entity_id:
516+ :return:
517+ """
518+
519+ return self .get ("sign" , sp_entity_id , default = [])
520+
498521 def filter (self , ava , sp_entity_id , mdstore , required = None , optional = None ):
499522 """ What attribute and attribute values returns depends on what
500523 the SP has said it wants in the request or in the metadata file and
@@ -568,16 +591,18 @@ def conditions(self, sp_entity_id):
568591 audience = [factory (saml .Audience ,
569592 text = sp_entity_id )])])
570593
571- def get_sign (self , sp_entity_id ):
572- """
573- Possible choices
574- "sign": ["response", "assertion", "on_demand"]
575-
576- :param sp_entity_id:
577- :return:
578- """
594+ def entity_category_attributes (self , ec ):
595+ # TODO: Not used. Remove?
596+ if not self ._restrictions :
597+ return None
579598
580- return self .get ("sign" , sp_entity_id , [])
599+ ec_maps = self ._restrictions ["default" ]["entity_categories" ]
600+ for ec_map in ec_maps :
601+ try :
602+ return ec_map [ec ]
603+ except KeyError :
604+ pass
605+ return []
581606
582607
583608class EntityCategories (object ):
0 commit comments