Skip to content

Commit

Permalink
Merge pull request #325 from SNEWS2/JostMigenda/FixBulkDownloader
Browse files Browse the repository at this point in the history
Fix bulk downloader
  • Loading branch information
Sheshuk committed May 3, 2024
2 parents a2f4d36 + 59fb68b commit d897a4d
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 98 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ jobs:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_JM }}
run: |
python -c 'import python.snewpy; python.snewpy._get_model_urls()'
cat python/snewpy/_model_urls.py
pip install .
python -c 'import snewpy; snewpy.get_models(models="Bollig_2016")'
rm -r SNEWPY_models/
python -c 'import snewpy; print(snewpy.__version__)'
python setup.py sdist bdist_wheel
twine upload dist/*
Expand Down
83 changes: 20 additions & 63 deletions python/snewpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

from ._version import __version__
from pathlib import Path
from sys import exit
import os

Expand All @@ -22,48 +21,45 @@
base_path = os.sep.join(src_path.split(os.sep)[:-2])
model_path = os.path.join(get_cache_dir(), 'snewpy/models')

def get_models(models=None, download_dir='SNEWPY_models'):
def get_models(models=None, download_dir=None):
"""Download model files from the snewpy repository.
Parameters
----------
models : list or str
Models to download. Can be 'all', name of a single model or list of model names.
download_dir : str
Local directory to download model files to.
[Deprecated, do not use.]
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from ._model_urls import model_urls
from ._model_downloader import _download as download
from warnings import warn
from .models.registry_model import all_models

for model in list(model_urls):
if model_urls[model] == []:
del model_urls[model]
continue
if download_dir is not None:
warn("The `download_dir` argument to `get_models` is deprecated and will be removed soon.", FutureWarning, stacklevel=2)

all_models = {m.__name__: m for m in all_models}
all_model_names = sorted(all_models.keys())

if models == "all":
models = model_urls.keys()
models = all_model_names
elif isinstance(models, str):
models = [models]
elif models == None:
elif models is None:
# Select model(s) to download
print(f"Available models in this version of SNEWPY: {list(model_urls.keys())}")
if not model_urls:
print("Error: `get_models()` only works after installing SNEWPY via `pip install snewpy`. "
"If you have cloned the git repo, model files are available in the `models/` folder.")
return False
print(f"Available models in SNEWPY v{__version__}: {all_model_names}")

selected = input("\nType a model name, 'all' to download all models or <Enter> to cancel: ").strip()
if selected == "all":
models = model_urls.keys()
models = all_model_names
elif selected == "":
exit()
elif selected in model_urls.keys():
models = [selected]
elif selected in all_model_names:
models = {selected}
while True:
selected = input("\nType another model name or <Enter> if you have selected all models you want to download: ").strip()
if selected in model_urls.keys():
models.append(selected)
if selected in all_model_names:
models.add(selected)
elif selected == "":
break
else:
Expand All @@ -74,25 +70,12 @@ def get_models(models=None, download_dir='SNEWPY_models'):

print(f"\nYou have selected the models: {models}\n")

# Download model files
if not os.path.isdir(download_dir):
print(f"Creating directory '{download_dir}' ...")
os.makedirs(download_dir)

pool = ThreadPoolExecutor(max_workers=8)
results = []
print(f"Downloading files for {models} to '{model_path}' ...")
for model in models:
model_dir = download_dir + '/' + model
print(f"Downloading files for '{model}' into '{model_dir}' ...")

for url in model_urls[model]:
local_file = model_dir + url.split(model, maxsplit=1)[1]
if os.path.exists(local_file) and local_file.find('README') == -1 and local_file.find('.ipynb') == -1:
print(f"File '{local_file}' already exists. Skipping download.")
else:
if not os.path.isdir(os.path.dirname(local_file)):
os.makedirs(os.path.dirname(local_file))
results.append(pool.submit(download, src=url, dest=Path(local_file)))
for progenitor in all_models[model].get_param_combinations():
results.append(pool.submit(all_models[model], **progenitor))

exceptions = []
for result in as_completed(results):
Expand All @@ -103,29 +86,3 @@ def get_models(models=None, download_dir='SNEWPY_models'):
print("Please check your internet connection and try again later. If this persists, please report it at https://github.com/SNEWS2/snewpy/issues")
exit(1)
pool.shutdown(wait=False)


def _get_model_urls():
"""List URLs of model files for the current release.
When building a snewpy release, generate a dictionary of available models
and the URLs at which the respective files are located. Users can then use
get_models() to interactively select which model(s) to download.
"""

repo_dir = os.path.normpath(os.path.dirname(os.path.abspath(__file__)) + '/../../')
url_base = 'https://github.com/SNEWS2/snewpy/raw/v' + __version__

with open(os.path.dirname(os.path.abspath(__file__)) + '/_model_urls.py', 'w') as f:
f.write('model_urls = {\n')
for model in sorted(os.listdir(repo_dir + '/models')):
urls = []
for root, dirs, files in os.walk(repo_dir + '/models/' + model):
for file in files:
urls.append(f'{url_base}{root[len(repo_dir):]}/{file}')

f.write(f' "{model}": [\n')
for url in sorted(urls):
f.write(f' "{url}",\n')
f.write(' ],\n')
f.write('}\n')
6 changes: 0 additions & 6 deletions python/snewpy/_model_urls.py

This file was deleted.

20 changes: 2 additions & 18 deletions python/snewpy/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging
from warnings import warn

from snewpy import get_models, model_path
from . import ccsn, presn


Expand All @@ -12,18 +10,13 @@ def __getattr__(name):
raise AttributeError(f"module {__name__} has no attribute {name}")


def _init_model(model_name, download=True, download_dir=model_path, **user_param):
def _init_model(model_name, **user_param):
"""Attempts to retrieve instantiated SNEWPY model using model class name and model parameters.
If a model name is valid, but is not found and `download`=True, this function will attempt to download the model
Parameters
----------
model_name : str
Name of SNEWPY model to import, must exactly match the name of the corresponding model class
download : bool
Switch for attempting to download model data if the first load attempt failed due to a missing file.
download_dir : str
Local directory to download model files to.
user_param : varies
User-requested model parameters used to initialize the model, if one is found.
Error checking is performed during model initialization
Expand Down Expand Up @@ -57,13 +50,4 @@ def _init_model(model_name, download=True, download_dir=model_path, **user_param
else:
raise ValueError(f"Unable to find model with name '{model_name}' in snewpy.models.ccsn or snewpy.models.presn")

try:
return getattr(module, model_name)(**user_param)
except FileNotFoundError as e:
logger = logging.getLogger()
logger.warning(f"Unable to find model {model_name} in {download_dir}")
if not download:
raise e
logger.warning(f"Attempting to download model...")
get_models(model_name, download_dir)
return getattr(module, model_name)(**user_param)
return getattr(module, model_name)(**user_param)
7 changes: 0 additions & 7 deletions python/snewpy/test/test_02_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,13 @@
Sukhbold_2015, Bollig_2016, Walk_2018, \
Walk_2019, Fornax_2019, Warren_2020, \
Kuroda_2020, Fornax_2021, Zha_2021
from snewpy._model_urls import model_urls
from astropy import units as u
from snewpy import model_path
import os


class TestModels(unittest.TestCase):

def test_model_urls(self):
"""Test that snewpy._model_urls.model_urls is empty. This should be populated if snewpy is downloaded from PyPI.
This serves as a guard against accidentally committing/merging a populated model_urls to main.
"""
self.assertFalse(model_urls)

def test_Nakazato_2013(self):
"""
Instantiate a set of 'Nakazato 2013' models
Expand Down

0 comments on commit d897a4d

Please sign in to comment.