Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions astroquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ def timeout(self, value):
else:
self._timeout = value

def request(self, session, cache_location=None, stream=False):
def request(self, session, cache_location=None, stream=False, auth=None):
return session.request(self.method, self.url, params=self.params,
data=self.data, headers=self.headers,
files=self.files, timeout=self.timeout,
stream=stream)
stream=stream, auth=auth)

def hash(self):
if self._hash is None:
Expand Down Expand Up @@ -105,7 +105,7 @@ def __call__(self, *args, **kwargs):

def _request(self, method, url, params=None, data=None, headers=None,
files=None, save=False, savedir='', timeout=None, cache=True,
stream=False):
stream=False, auth=None):
"""
A generic HTTP request method, similar to `requests.Session.request` but
with added caching-related tools
Expand All @@ -121,44 +121,50 @@ def _request(self, method, url, params=None, data=None, headers=None,
params : None or dict
data : None or dict
headers : None or dict
auth : None or dict
files : None or dict
See `requests.request`
save : bool
Whether to save the file to a local directory. Caching will happen independent of
this parameter if `BaseQuery.cache_location` is set, but the save location can be
overridden if ``save==True``
Whether to save the file to a local directory. Caching will happen
independent of this parameter if `BaseQuery.cache_location` is set,
but the save location can be overridden if ``save==True``
savedir : str
The location to save the local file if you want to save it
somewhere other than `BaseQuery.cache_location`
"""
if save:
local_filename = url.split('/')[-1]
local_filepath = os.path.join(self.cache_location or savedir or '.', local_filename)
local_filepath = os.path.join(self.cache_location or savedir or
'.', local_filename)
log.info("Downloading {0}...".format(local_filename))
self._download_file(url, local_filepath, timeout=timeout)
self._download_file(url, local_filepath, timeout=timeout,
auth=auth)
return local_filepath
else:
query = AstroQuery(method, url, params=params, data=data,
headers=headers, files=files, timeout=timeout)
if ((self.cache_location is None) or (not self._cache_active) or
(not cache)):
with suspend_cache(self):
response = query.request(self.__session, stream=stream)
response = query.request(self.__session, stream=stream,
auth=auth)
else:
response = query.from_cache(self.cache_location)
if not response:
response = query.request(self.__session,
self.cache_location,
stream=stream)
stream=stream,
auth=auth)
to_cache(response, query.request_file(self.cache_location))
return response

def _download_file(self, url, local_filepath, timeout=None):
def _download_file(self, url, local_filepath, timeout=None, auth=None):
"""
Download a file. Resembles `astropy.utils.data.download_file` but uses
the local ``__session``
"""
response = self.__session.get(url, timeout=timeout, stream=True)
response = self.__session.get(url, timeout=timeout, stream=True,
auth=auth)
if 'content-length' in response.headers:
length = int(response.headers['content-length'])
else:
Expand Down
3 changes: 2 additions & 1 deletion astroquery/utils/testing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ class MockResponse(object):
"""

def __init__(self, content=None, url=None, headers={},
content_type=None, stream=False):
content_type=None, stream=False, auth=None):
assert content is None or hasattr(content, 'decode')
self.content = content
self.raw = content
self.headers = headers
if content_type is not None:
self.headers.update({'Content-Type':content_type})
self.url = url
self.auth = auth

def iter_lines(self):
c = self.text.split("\n")
Expand Down