Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion astroquery/eso/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def retrieve_data(self, datasets, cache=True):

Returns
-------
files : list of strings
files : list of strings or string
List of files that have been locally downloaded from the archive.

Examples
Expand All @@ -475,7 +475,10 @@ def retrieve_data(self, datasets, cache=True):
files = []

if isinstance(datasets, six.string_types):
return_list = False
datasets = [datasets]
else:
return_list = True
if not isinstance(datasets, (list, tuple, np.ndarray)):
raise TypeError("Datasets must be given as a list of strings.")

Expand Down Expand Up @@ -546,7 +549,10 @@ def retrieve_data(self, datasets, cache=True):
fileLink = "http://dataportal.eso.org/dataPortal"+fileId.attrs['value'].split()[1]
filename = self._request("GET", fileLink, save=True)
files.append(system_tools.gunzip(filename))
self._session.redirect_cache.clear() # EMpty the redirect cache of this request session
log.info("Done!")
if (not return_list) and (len(files)==1):
files = files[0]
return files

def verify_data_exists(self, dataset):
Expand Down
7 changes: 7 additions & 0 deletions astroquery/eso/tests/test_eso_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,10 @@ def test_retrieve_data(self):
result = eso.retrieve_data("MIDI.2014-07-25T02:03:11.561")
assert len(result)>0
assert "MIDI.2014-07-25T02:03:11.561" in result[0]

@pytest.mark.skipif('not Eso.USERNAME')
def test_retrieve_data_twice(self):
eso = Eso()
eso.login()
result1 = eso.retrieve_data("MIDI.2014-07-25T02:03:11.561")
result2 = eso.retrieve_data("AMBER.2006-03-14T07:40:19.830")
15 changes: 8 additions & 7 deletions astroquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


def to_cache(response, cache_file):
log.debug("Caching data to {0}".format(cache_file))
with open(cache_file, "wb") as f:
pickle.dump(response, f)

Expand Down Expand Up @@ -71,11 +72,9 @@ def hash(self):

def request_file(self, cache_location):
fn = os.path.join(cache_location, self.hash() + ".pickle")
log.debug("Request file is {0}".format(fn))
return fn

def from_cache(self, cache_location):
log.debug("Retrieving data from {0}".format(cache_location))
request_file = self.request_file(cache_location)
try:
with open(request_file, "rb") as f:
Expand All @@ -84,6 +83,8 @@ def from_cache(self, cache_location):
response = None
except:
response = None
if response:
log.debug("Retrieving data from {0}".format(request_file))
return response


Expand All @@ -95,7 +96,7 @@ class BaseQuery(object):
"""

def __init__(self):
self.__session = requests.session()
self._session = requests.session()
self.cache_location = os.path.join(paths.get_cache_dir(), 'astroquery',
self.__class__.__name__.split("Class")[0])
if not os.path.exists(self.cache_location):
Expand Down Expand Up @@ -157,12 +158,12 @@ def _request(self, method, url, params=None, data=None, headers=None,
if ((self.cache_location is None) or (not self._cache_active) or
(not cache)):
with suspend_cache(self):
response = query.request(self.__session, stream=stream,
response = query.request(self._session, stream=stream,
auth=auth)
else:
response = query.from_cache(self.cache_location)
if not response:
response = query.request(self.__session,
response = query.request(self._session,
self.cache_location,
stream=stream,
auth=auth)
Expand All @@ -172,9 +173,9 @@ def _request(self, method, url, params=None, data=None, headers=None,
def _download_file(self, url, local_filepath, timeout=None, auth=None):
"""
Download a file. Resembles `astropy.utils.data.download_file` but uses
the local ``__session``
the local ``_session``
"""
response = self.__session.get(url, timeout=timeout, stream=True,
response = self._session.get(url, timeout=timeout, stream=True,
auth=auth)
if 'content-length' in response.headers:
length = int(response.headers['content-length'])
Expand Down