Skip to content

Commit

Permalink
Refactor session to not be a singleton
Browse files Browse the repository at this point in the history
Right now, session was a singleton, and there was only one of
them and all the calls were using it.  This is great from
simplicity point of view, but for auth, you need more context
to set the correct auth methods on each call.  This context
will end up being a part of the session, which also makes the
call.

This commit creates the session at the service level, and passes
it down into all the subclasses it creates, such as queries,
results, and records as a keyword parameter.
  • Loading branch information
cbanek committed Aug 20, 2019
1 parent 3d30757 commit 19e8965
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 85 deletions.
24 changes: 14 additions & 10 deletions pyvo/dal/adhoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from astropy.utils.collections import HomogeneousList

from ..utils.decorators import stream_decode_content
from ..utils.http import session


# monkeypatch astropy with group support in RESOURCE
Expand Down Expand Up @@ -89,8 +88,8 @@ class AdhocServiceResultsMixin:
"""
Mixing for adhoc:service functionallity for results classes.
"""
def __init__(self, votable, url=None):
super().__init__(votable, url=url)
def __init__(self, votable, url=None, session=None):
super().__init__(votable, url=url, session=session)

self._adhocservices = list(
resource for resource in votable.resources
Expand Down Expand Up @@ -178,7 +177,7 @@ def getdatalink(self):
def getdataset(self, timeout=None):
try:
url = next(self.getdatalink().bysemantics('#this')).access_url
response = session.get(url, stream=True, timeout=timeout)
response = self._session.get(url, stream=True, timeout=timeout)
response.raise_for_status()
return response.raw
except (DALServiceError, ValueError, StopIteration):
Expand Down Expand Up @@ -258,14 +257,17 @@ class DatalinkQuery(DALQuery):
:py:attr:`~pyvo.dal.query.DALQuery.baseurl` to send a configured
query to another service.
A session can also optionally be passed in that will be used for
network transactions made by this object to remote services.
In addition to the search constraint attributes described below, search
parameters can be set generically by name via dict semantics.
The typical function for submitting the query is ``execute()``; however,
alternate execute functions provide the response in different forms,
allowing the caller to take greater control of the result processing.
"""
@classmethod
def from_resource(cls, row, resource, **kwargs):
def from_resource(cls, row, resource, session=None, **kwargs):
"""
Creates a instance from a Record and a Datalink Resource.
Expand Down Expand Up @@ -313,10 +315,10 @@ def from_resource(cls, row, resource, **kwargs):
except KeyError:
query_params[name] = query_param

return cls(accessurl, **query_params)
return cls(accessurl, session=session, **query_params)

def __init__(
self, baseurl, id=None, responseformat=None, **keywords):
self, baseurl, id=None, responseformat=None, session=None, **keywords):
"""
initialize the query object with the given parameters
Expand All @@ -328,8 +330,10 @@ def __init__(
the dataset identifier
responseformat : str
the output format
session : object
optional session to use for network requests
"""
super().__init__(baseurl, **keywords)
super().__init__(baseurl, session=session, **keywords)

if id:
self["ID"] = id
Expand All @@ -350,7 +354,7 @@ def execute(self):
DALFormatError
for errors parsing the VOTable response
"""
return DatalinkResults(self.execute_votable(), url=self.queryurl)
return DatalinkResults(self.execute_votable(), url=self.queryurl, session=self._session)


class DatalinkResults(DatalinkResultsMixin, DALResults):
Expand Down Expand Up @@ -424,7 +428,7 @@ def getrecord(self, index):
--------
Record
"""
return DatalinkRecord(self, index)
return DatalinkRecord(self, index, session=self._session)

def bysemantics(self, semantics):
"""
Expand Down
7 changes: 5 additions & 2 deletions pyvo/dal/mimetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from astropy.io.fits import HDUList

from ..utils.http import session
from ..utils.http import use_session


mimetypes.add_type('application/fits', 'fits')
Expand Down Expand Up @@ -56,7 +56,7 @@ def mime2extension(mimetype, default=None):
return ext


def mime_object_maker(url, mimetype):
def mime_object_maker(url, mimetype, session=None):
"""
return a data object suitable for the mimetype given.
this will either return a astropy fits object or a pyvo DALResults object,
Expand All @@ -68,7 +68,10 @@ def mime_object_maker(url, mimetype):
the object download url
mimetype : str
the content mimetype
session : object
optional session to use for network requests
"""
session = use_session(session)
mimetype = mimeparse.parse_mime_type(mimetype)

if mimetype[0] == 'text':
Expand Down
46 changes: 31 additions & 15 deletions pyvo/dal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from .. import samp

from ..utils.decorators import stream_decode_content
from ..utils.http import session
from ..utils.http import use_session


class DALService:
Expand All @@ -54,16 +54,19 @@ class DALService:
endpoint.
"""

def __init__(self, baseurl):
def __init__(self, baseurl, session=None):
"""
instantiate the service connecting it to a base URL
Parameters
----------
baseurl : str
the base URL that should be used for forming queries to the service.
session : object
optional session to use for network requests
"""
self._baseurl = baseurl
self._session = use_session(session)

@property
def baseurl(self):
Expand Down Expand Up @@ -105,7 +108,7 @@ def create_query(self, **keywords):
DALQuery
a generic query object
"""
q = DALQuery(self.baseurl, **keywords)
q = DALQuery(self.baseurl, session=self._session, **keywords)
return q

def describe(self):
Expand All @@ -119,18 +122,22 @@ class DALQuery(dict):
functions will submit the query and return the results.
The base URL for the query can be changed via the baseurl property.
A session can also optionally be passed in that will be used for
network transactions made by this object to remote services.
"""

_ex = None

def __init__(self, baseurl, **keywords):
def __init__(self, baseurl, session=None, **keywords):
"""
initialize the query object with a baseurl
"""
if type(baseurl) == bytes:
baseurl = baseurl.decode("utf-8")

self._baseurl = baseurl.rstrip("?")
self._session = use_session(session)

self.update({key.upper(): value for key, value in keywords.items()})

Expand All @@ -156,7 +163,7 @@ def execute(self):
DALFormatError
for errors parsing the VOTable response
"""
return DALResults(self.execute_votable(), self.queryurl)
return DALResults(self.execute_votable(), self.queryurl, session=self._session)

def execute_raw(self):
"""
Expand Down Expand Up @@ -198,7 +205,7 @@ def submit(self):
url = self.queryurl
params = {k: v for k, v in self.items()}

response = session.get(url, params=params, stream=True)
response = self._session.get(url, params=params, stream=True)
return response

def execute_votable(self):
Expand Down Expand Up @@ -257,19 +264,23 @@ class DALResults:
"""
@classmethod
@stream_decode_content
def _from_result_url(cls, result_url):
def _from_result_url(cls, result_url, session):
return session.get(result_url, stream=True).raw

@classmethod
def from_result_url(cls, result_url):
def from_result_url(cls, result_url, session=None):
"""
Create a result object from a url.
Uses the optional session to make the request.
"""
session = use_session(session)
return cls(
votableparse(cls._from_result_url(result_url).read),
url=result_url)
votableparse(cls._from_result_url(result_url, session).read),
url=result_url,
session=session)

def __init__(self, votable, url=None):
def __init__(self, votable, url=None, session=None):
"""
initialize the cursor. This constructor is not typically called
by directly applications; rather an instance is obtained from calling
Expand All @@ -282,6 +293,8 @@ def __init__(self, votable, url=None):
astropy.io.votable.tree.VOTableFile instance.
url : str
the URL that produced the response
session : object
optional session to use for network requests
Raises
------
Expand All @@ -295,6 +308,8 @@ def __init__(self, votable, url=None):
self._votable = votable

self._url = url
self._session = use_session(session)

self._status = self._findstatus(votable)
if self._status[0].lower() not in ("ok", "overflow"):
raise DALQueryError(self._status[1], self._status[0], url)
Expand Down Expand Up @@ -525,7 +540,7 @@ def getrecord(self, index):
--------
Record
"""
return Record(self, index)
return Record(self, index, session=self._session)

def getvalue(self, name, index):
"""
Expand Down Expand Up @@ -611,9 +626,10 @@ class Record(Mapping):
additional functions for access to service type-specific data.
"""

def __init__(self, results, index):
def __init__(self, results, index, session=None):
self._results = results
self._index = index
self._session = use_session(session)
self._mapping = collections.OrderedDict(
zip(
results.fieldnames,
Expand Down Expand Up @@ -738,9 +754,9 @@ def getdataset(self, timeout=None):
raise KeyError("no dataset access URL recognized in record")

if timeout:
response = session.get(url, stream=True, timeout=timeout)
response = self._session.get(url, stream=True, timeout=timeout)
else:
response = session.get(url, stream=True)
response = self._session.get(url, stream=True)

response.raise_for_status()
return response.raw
Expand Down
18 changes: 11 additions & 7 deletions pyvo/dal/scs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,18 @@ class SCSService(DALService):
a representation of a Cone Search service
"""

def __init__(self, baseurl):
def __init__(self, baseurl, session=None):
"""
instantiate a Cone Search service
Parameters
----------
baseurl : str
the base URL for submitting search queries to the service.
session : object
optional session to use for network requests
"""
super().__init__(baseurl)
super().__init__(baseurl, session=session)

def _get_metadata(self):
"""
Expand Down Expand Up @@ -220,7 +222,7 @@ def create_query(self, pos=None, radius=None, verbosity=None, **keywords):
--------
SCSQuery
"""
return SCSQuery(self.baseurl, pos, radius, verbosity, **keywords)
return SCSQuery(self.baseurl, pos, radius, verbosity, session=self._session, **keywords)

def describe(self):
print(self.description)
Expand Down Expand Up @@ -272,7 +274,7 @@ class SCSQuery(DALQuery):
"""

def __init__(
self, baseurl, pos=None, radius=None, verbosity=None, **keywords):
self, baseurl, pos=None, radius=None, verbosity=None, session=None, **keywords):
"""
initialize the query object with a baseurl and the given parameters
Expand All @@ -292,8 +294,10 @@ def __init__(
to return in the result table. 0 means the minimum
set of columns, 3 means as many columns as are
available.
session : object
optional session to use for network requests
"""
super().__init__(baseurl)
super().__init__(baseurl, session=session)

if pos is not None:
self.pos = pos
Expand Down Expand Up @@ -406,7 +410,7 @@ def execute(self):
DALFormatError
for errors parsing the VOTable response
"""
return SCSResults(self.execute_votable(), url=self.queryurl)
return SCSResults(self.execute_votable(), url=self.queryurl, session=self._session)


class SCSResults(DALResults, DatalinkResultsMixin):
Expand Down Expand Up @@ -523,7 +527,7 @@ def getrecord(self, index):
--------
Record
"""
return SCSRecord(self, index)
return SCSRecord(self, index, session=self._session)


class SCSRecord(DatalinkRecordMixin, Record):
Expand Down
Loading

0 comments on commit 19e8965

Please sign in to comment.