Skip to content

Commit

Permalink
Merge pull request #224 from IdentityPython/ft-more_typing
Browse files Browse the repository at this point in the history
more typing
  • Loading branch information
leifj committed Apr 22, 2021
2 parents c454257 + d5f1903 commit 052c556
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 193 deletions.
129 changes: 74 additions & 55 deletions src/pyff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
from datetime import datetime, timedelta
from json import dumps
from typing import Any, Iterable, List, Mapping
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple

import pkg_resources
import pyramid.httpexceptions as exc
Expand All @@ -13,6 +13,7 @@
from lxml import etree
from pyramid.config import Configurator
from pyramid.events import NewRequest
from pyramid.request import Request
from pyramid.response import Response
from six import b
from six.moves.urllib_parse import quote_plus
Expand All @@ -22,27 +23,29 @@
from pyff.logs import get_log
from pyff.pipes import plumbing
from pyff.repo import MDRepository
from pyff.resource import Resource, ResourceInfo
from pyff.resource import Resource
from pyff.samlmd import entity_display_name
from pyff.utils import b2u, dumptree, duration2timedelta, hash_id, json_serializer, utc_now
from pyff.utils import b2u, dumptree, hash_id, json_serializer, utc_now

log = get_log(__name__)


class NoCache(object):
def __init__(self):
""" Dummy implementation for when caching isn't enabled """

def __init__(self) -> None:
pass

def __getitem__(self, item):
def __getitem__(self, item: Any) -> None:
return None

def __setitem__(self, instance, value):
def __setitem__(self, instance: Any, value: Any) -> Any:
return value


def robots_handler(request):
def robots_handler(request: Request) -> Response:
"""
Impelements robots.txt
Implements robots.txt
:param request: the HTTP request
:return: robots.txt
Expand All @@ -55,7 +58,7 @@ def robots_handler(request):
)


def status_handler(request):
def status_handler(request: Request) -> Response:
"""
Implements the /api/status endpoint
Expand All @@ -80,34 +83,38 @@ def status_handler(request):


class MediaAccept(object):
def __init__(self, accept):
def __init__(self, accept: str):
self._type = AcceptableType(accept)

def has_key(self, key):
def has_key(self, key: Any) -> bool: # Literal[True]:
return True

def get(self, item):
def get(self, item: Any) -> Any:
return self._type.matches(item)

def __contains__(self, item):
def __contains__(self, item: Any) -> Any:
return self._type.matches(item)

def __str__(self):
def __str__(self) -> str:
return str(self._type)


xml_types = ('text/xml', 'application/xml', 'application/samlmetadata+xml')


def _is_xml_type(accepter):
def _is_xml_type(accepter: MediaAccept) -> bool:
return any([x in accepter for x in xml_types])


def _is_xml(data):
def _is_xml(data: Any) -> bool:
return isinstance(data, (etree._Element, etree._ElementTree))


def _fmt(data, accepter):
def _fmt(data: Any, accepter: MediaAccept) -> Tuple[str, str]:
"""
Format data according to the accepted content type of the requester.
Return data as string (either XML or json) and a content-type.
"""
if data is None or len(data) == 0:
return "", 'text/plain'
if _is_xml(data) and _is_xml_type(accepter):
Expand All @@ -127,7 +134,7 @@ def call(entry: str) -> None:
return None


def request_handler(request):
def request_handler(request: Request) -> Response:
"""
The main GET request handler for pyFF. Implements caching and forwards the request to process_handler
Expand All @@ -146,7 +153,7 @@ def request_handler(request):
return r


def process_handler(request):
def process_handler(request: Request) -> Response:
"""
The main request handler for pyFF. Implements API call hooks and content negotiation.
Expand All @@ -155,7 +162,8 @@ def process_handler(request):
"""
_ctypes = {'xml': 'application/samlmetadata+xml;application/xml;text/xml', 'json': 'application/json'}

def _d(x, do_split=True):
def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional[str]]:
""" Split a path into a base component and an extension. """
if x is not None:
x = x.strip()

Expand All @@ -170,7 +178,7 @@ def _d(x, do_split=True):

return x, None

log.debug(request)
log.debug(f'Processing request: {request}')

if request.matchdict is None:
raise exc.exception_response(400)
Expand All @@ -182,18 +190,18 @@ def _d(x, do_split=True):
pass

entry = request.matchdict.get('entry', 'request')
path = list(request.matchdict.get('path', []))
path_elem = list(request.matchdict.get('path', []))
match = request.params.get('q', request.params.get('query', None))

# Enable matching on scope.
match = match.split('@').pop() if match and not match.endswith('@') else match
log.debug("match={}".format(match))

if 0 == len(path):
path = ['entities']
if not path_elem:
path_elem = ['entities']

alias = path.pop(0)
path = '/'.join(path)
alias = path_elem.pop(0)
path = '/'.join(path_elem)

# Ugly workaround bc WSGI drops double-slashes.
path = path.replace(':/', '://')
Expand Down Expand Up @@ -226,23 +234,31 @@ def _d(x, do_split=True):
accept = str(request.accept).split(',')[0]
valid_accept = accept and not ('application/*' in accept or 'text/*' in accept or '*/*' in accept)

path_no_extension, extension = _d(path, True)
accept_from_extension = _ctypes.get(extension, accept)
new_path: Optional[str] = path
path_no_extension, extension = _d(new_path, True)
accept_from_extension = accept
if extension:
accept_from_extension = _ctypes.get(extension, accept)

if policy == 'extension':
path = path_no_extension
new_path = path_no_extension
if not valid_accept:
accept = accept_from_extension
elif policy == 'adaptive':
if not valid_accept:
path = path_no_extension
new_path = path_no_extension
accept = accept_from_extension

if pfx and path:
q = "{%s}%s" % (pfx, path)
path = "/%s/%s" % (alias, path)
if not accept:
log.warning('Could not determine accepted response type')
raise exc.exception_response(400)

q: Optional[str]
if pfx and new_path:
q = f'{{{pfx}}}{new_path}'
new_path = f'/{alias}/{new_path}'
else:
q = path
q = new_path

try:
accepter = MediaAccept(accept)
Expand All @@ -254,18 +270,19 @@ def _d(x, do_split=True):
'url': request.current_route_url(),
'select': q,
'match': match.lower() if match else match,
'path': path,
'path': new_path,
'stats': {},
}

r = p.process(request.registry.md, state=state, raise_exceptions=True, scheduler=request.registry.scheduler)
log.debug(r)
log.debug(f'Plumbing process result: {r}')
if r is None:
r = []

response = Response()
response.headers.update(state.get('headers', {}))
ctype = state.get('headers').get('Content-Type', None)
_headers = state.get('headers', {})
response.headers.update(_headers)
ctype = _headers.get('Content-Type', None)
if not ctype:
r, t = _fmt(r, accepter)
ctype = t
Expand All @@ -280,20 +297,20 @@ def _d(x, do_split=True):
import traceback

log.debug(traceback.format_exc())
log.warning(ex)
log.warning(f'Exception from processing pipeline: {ex}')
raise exc.exception_response(409)
except BaseException as ex:
import traceback

log.debug(traceback.format_exc())
log.error(ex)
log.error(f'Exception from processing pipeline: {ex}')
raise exc.exception_response(500)

if request.method == 'GET':
raise exc.exception_response(404)


def webfinger_handler(request):
def webfinger_handler(request: Request) -> Response:
"""An implementation the webfinger protocol
(http://tools.ietf.org/html/draft-ietf-appsawg-webfinger-12)
in order to provide information about up and downstream metadata available at
Expand Down Expand Up @@ -324,7 +341,7 @@ def webfinger_handler(request):
"subject": "http://reep.refeds.org:8080"
}
Depending on which version of pyFF your're running and the configuration you
Depending on which version of pyFF you're running and the configuration you
may also see downstream metadata listed using the 'role' attribute to the link
elements.
"""
Expand All @@ -335,11 +352,11 @@ def webfinger_handler(request):
if resource is None:
resource = request.host_url

jrd = dict()
dt = datetime.now() + duration2timedelta("PT1H")
jrd: Dict[str, Any] = dict()
dt = datetime.now() + timedelta(hours=1)
jrd['expires'] = dt.isoformat()
jrd['subject'] = request.host_url
links = list()
links: List[Dict[str, Any]] = list()
jrd['links'] = links

_dflt_rels = {
Expand All @@ -352,7 +369,7 @@ def webfinger_handler(request):
else:
rel = [rel]

def _links(url, title=None):
def _links(url: str, title: Any = None) -> None:
if url.startswith('/'):
url = url.lstrip('/')
for r in rel:
Expand Down Expand Up @@ -381,7 +398,7 @@ def _links(url, title=None):
return response


def resources_handler(request):
def resources_handler(request: Request) -> Response:
"""
Implements the /api/resources endpoint
Expand Down Expand Up @@ -409,7 +426,7 @@ def _info(r: Resource) -> Mapping[str, Any]:
return response


def pipeline_handler(request):
def pipeline_handler(request: Request) -> Response:
"""
Implements the /api/pipeline endpoint
Expand All @@ -422,7 +439,7 @@ def pipeline_handler(request):
return response


def search_handler(request):
def search_handler(request: Request) -> Response:
"""
Implements the /api/search endpoint
Expand All @@ -438,7 +455,7 @@ def search_handler(request):
log.debug("match={}".format(match))
store = request.registry.md.store

def _response():
def _response() -> Generator[bytes, bytes, None]:
yield b('[')
in_loop = False
entities = store.search(query=match.lower(), entity_filter=entity_filter)
Expand All @@ -454,8 +471,8 @@ def _response():
return response


def add_cors_headers_response_callback(event):
def cors_headers(request, response):
def add_cors_headers_response_callback(event: NewRequest) -> None:
def cors_headers(request: Request, response: Response) -> None:
response.headers.update(
{
'Access-Control-Allow-Origin': '*',
Expand All @@ -469,7 +486,7 @@ def cors_headers(request, response):
event.request.add_response_callback(cors_headers)


def launch_memory_usage_server(port=9002):
def launch_memory_usage_server(port: int = 9002) -> None:
import cherrypy
import dowser

Expand All @@ -479,7 +496,7 @@ def launch_memory_usage_server(port=9002):
cherrypy.engine.start()


def mkapp(*args, **kwargs):
def mkapp(*args: Any, **kwargs: Any) -> Any:
md = kwargs.pop('md', None)
if md is None:
md = MDRepository()
Expand All @@ -501,7 +518,9 @@ def mkapp(*args, **kwargs):
for mn in config.modules:
importlib.import_module(mn)

pipeline = args or None
pipeline = None
if args:
pipeline = list(args)
if pipeline is None and config.pipeline:
pipeline = [config.pipeline]

Expand Down

0 comments on commit 052c556

Please sign in to comment.