Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model registry cleanup #230

Merged
merged 14 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 13 additions & 31 deletions doc/nb/FlavorTransformation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,10 @@
"from snewpy.models import Nakazato_2013\n",
"from snewpy.flavor_transformation import AdiabaticMSW, NonAdiabaticMSWH, \\\n",
" TwoFlavorDecoherence, ThreeFlavorDecoherence, \\\n",
" NeutrinoDecay, AdiabaticMSWes, NonAdiabaticMSWes\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mpl.rc('font', size=18)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"infile = '../../models/Nakazato_2013/nakazato-shen-z0.004-t_rev100ms-s20.0.fits'\n",
"model = Nakazato_2013(infile)"
" NeutrinoDecay, AdiabaticMSWes, NonAdiabaticMSWes\n",
"\n",
"mpl.rc('font', size=18)\n",
"%matplotlib inline"
]
},
{
Expand All @@ -53,6 +37,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Nakazato_2013(progenitor_mass=20*u.solMass, revival_time=100*u.ms, metallicity=0.004, eos='shen')\n",
"model"
]
},
Expand Down Expand Up @@ -249,8 +234,7 @@
" ax = axes[1][0]\n",
" ax.set(ylabel=r'flux [$10^{16}$ erg$^{-1}$ cm$^{-2}$ s$^{-1}$]')\n",
" \n",
" return fig\n",
"\n"
" return fig"
]
},
{
Expand Down Expand Up @@ -284,18 +268,11 @@
"fig = plot_spectra(model, xf_nmo, xf_imo, 100*u.ms)\n",
"fig.savefig('spectra_adiabaticmswes.pdf')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.5 ('snews')",
"language": "python",
"name": "python3"
},
Expand All @@ -309,7 +286,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.9.5"
},
"vscode": {
"interpreter": {
"hash": "e2528887d751495e023d57d695389d9a04f4c4d2e5866aaf6dc03a1ed45c573e"
}
}
},
"nbformat": 4,
Expand Down
27 changes: 13 additions & 14 deletions doc/source/gettingstarted.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ Download Supernova Models

SNEWPY includes a large number of supernova models from different simulation groups.
Since these models have a size of several 100 MB, they are not included in the initial install.
Instead, after installing, run the following command to download models you want to use:
Instead, SNEWPY automatically loads these files the first time you use a model. By default,
they are downloaded to a hidden directory given by ``snewpy.model_path``.

Alternatively, you can run the following command to bulk download model files to the current directory:

.. code-block:: console

$ python -c 'import snewpy; snewpy.get_models()'

By default, they will be downloaded to a subdirectory named ``SNEWPY-models/<model_name>/`` in the current directory.

.. note::

Each model includes a README file with more information, usually including a reference to the corresponding publication
Expand All @@ -39,24 +40,24 @@ This example script shows how to use SNEWPY to compare the luminosity of two dif

.. code-block:: python

import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt

import snewpy
from snewpy.models import Nakazato_2013, Bollig_2016
from snewpy.models.ccsn import Nakazato_2013, Bollig_2016
from snewpy.neutrino import Flavor

mpl.rc('font', size=16)
%matplotlib inline

# Download a few model files we can plot
snewpy.get_models(models=["Nakazato_2013", "Bollig_2016"])

# Read data from downloaded files
nakazato = Nakazato_2013('SNEWPY_models/Nakazato_2013/nakazato-shen-z0.004-t_rev100ms-s20.0.fits')
bollig = Bollig_2016('SNEWPY_models/Bollig_2016/s27.0c') # This model has one file per flavor. Use common prefix, not full filename.
# Initialise two different models. This automatically downloads the required data files.
nakazato = Nakazato_2013(progenitor_mass=20*u.solMass, revival_time=100*u.ms, metallicity=0.004, eos='shen')
bollig = Bollig_2016(progenitor_mass=27*u.solMass)

# Plot luminosity of both models
fig, ax = plt.subplots(1, figsize=(10, 6))

for flavor in Flavor:
ax.plot(nakazato.time, nakazato.luminosity[flavor]/1e51, # Report luminosity in units foe/s
label=flavor.to_tex() + ' (Nakazato)',
Expand All @@ -75,14 +76,12 @@ This example script shows how to use SNEWPY to compare the luminosity of two dif
ax.grid()
ax.legend(loc='upper right', ncol=2, fontsize=18)

fig.tight_layout()

This will generate the following figure:

.. image:: luminosity-comparison.*


The SNEWPY repository contains many Jupyter notebooks in ``doc/nb/`` with sample code
The SNEWPY repository contains many Jupyter notebooks in ``models/<model-name>/`` or ``doc/nb/`` with sample code
showing different models or how to apply flavor transformations to the neutrino fluxes.

More advanced usage of SNEWPY requires SNOwGLoBES and is described in the following section.
222 changes: 48 additions & 174 deletions models/Bollig_2016/Bollig_2016.ipynb

Large diffs are not rendered by default.

282 changes: 51 additions & 231 deletions models/Fornax_2019/Fornax_2019.ipynb

Large diffs are not rendered by default.

278 changes: 67 additions & 211 deletions models/Fornax_2021/Fornax_2021.ipynb

Large diffs are not rendered by default.

255 changes: 47 additions & 208 deletions models/Kuroda_2020/Kuroda_2020.ipynb

Large diffs are not rendered by default.

151 changes: 68 additions & 83 deletions models/Nakazato_2013/Nakazato_2013.ipynb

Large diffs are not rendered by default.

126 changes: 59 additions & 67 deletions models/OConnor_2013/OConnor_2013.ipynb

Large diffs are not rendered by default.

164 changes: 47 additions & 117 deletions models/OConnor_2015/OConnor_2015.ipynb

Large diffs are not rendered by default.

355 changes: 43 additions & 312 deletions models/Sukhbold_2015/Sukhbold_2015.ipynb

Large diffs are not rendered by default.

241 changes: 51 additions & 190 deletions models/Tamborra_2014/Tamborra_2014.ipynb

Large diffs are not rendered by default.

164 changes: 46 additions & 118 deletions models/Walk_2018/Walk_2018.ipynb

Large diffs are not rendered by default.

164 changes: 46 additions & 118 deletions models/Walk_2019/Walk_2019.ipynb

Large diffs are not rendered by default.

499 changes: 58 additions & 441 deletions models/Warren_2020/Warren_2020.ipynb

Large diffs are not rendered by default.

162 changes: 52 additions & 110 deletions models/Zha_2021/Zha_2021.ipynb

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions python/snewpy/_model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,6 @@ def load(self) -> Path:
self.check()
return self.path

def open(self, flags='r'):
""" Load and open the local file, return the file object"""
return open(self.load(), flags)


def from_zenodo(zenodo_id:str, model:str, filename:str, path:str=model_path):
"""Access files on Zenodo.
Expand Down Expand Up @@ -162,7 +158,7 @@ def from_github(release_version:str, model:str, filename:str, path:str=model_pat
remote = github_url)


def get_model_data(model:str, filename:str, path:str=model_path):
def get_model_data(model: str, filename: str, path: str = model_path) -> Path:
"""Access model data. Configuration for each model is in a YAML file
distributed with SNEWPY.

Expand All @@ -174,10 +170,10 @@ def get_model_data(model:str, filename:str, path:str=model_path):

Returns
-------
file : FileHandle object.
Path of downloaded file.
"""
if os.path.isabs(filename):
return FileHandle(path=Path(filename))
return Path(filename)

params = { 'model':model, 'filename':filename, 'path':path }

Expand All @@ -193,12 +189,13 @@ def get_model_data(model:str, filename:str, path:str=model_path):

if repo == 'github':
params['release_version'] = modconf['release_version']
return from_github(**params)
fh = from_github(**params)
elif repo == 'zenodo':
params['zenodo_id'] = modconf['zenodo_id']
return from_zenodo(**params)
fh = from_zenodo(**params)
else:
raise ValueError(f'Repository {repo} not recognized')
return fh.load()
else:
raise KeyError(f'No configuration for {model}')

64 changes: 62 additions & 2 deletions python/snewpy/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,69 @@

from . import base, ccsn ,presn
import logging
from warnings import warn

from snewpy import get_models, model_path
from . import ccsn, presn


def __getattr__(name):
if name in dir(ccsn):
warn(f"{__name__}.{name} is moved to {__name__}.ccsn.{name}", FutureWarning)
return getattr(ccsn, name)
raise AttributeError(f"module {__name__} has no attribute {name}")


def _init_model(model_name, download=True, download_dir=model_path, **user_param):
"""Attempts to retrieve instantiated SNEWPY model using model class name and model parameters.
If a model name is valid, but is not found and `download`=True, this function will attempt to download the model

Parameters
----------
model_name : str
Name of SNEWPY model to import, must exactly match the name of the corresponding model class
download : bool
Switch for attempting to download model data if the first load attempt failed due to a missing file.
download_dir : str
Local directory to download model files to.
user_param : varies
User-requested model parameters used to initialize the model, if one is found.
Error checking is performed during model initialization

Raises
------
ValueError
If the requested model_name does not match any SNEWPY models

See Also
--------
snewpy.models.ccsn
snewpy.models.presn

Example
-------
>>> from snewpy.models import _init_model; import astropy.units as u
>>> _init_model('Nakazato_2013', progenitor_mass=13*u.Msun, metallicity=0.004, revival_time=0*u.s, eos='shen')
Nakazato_2013 Model: nakazato-shen-BH-z0.004-s30.0.fits
Progenitor mass : 30.0 solMass
EOS : Shen
Metallicity : 0.004
Revival time : 0.0 ms

:meta private:
"""
if model_name in dir(ccsn):
module = ccsn
elif model_name in dir(presn):
module = presn
else:
raise ValueError(f"Unable to find model with name '{model_name}' in snewpy.models.ccsn or snewpy.models.presn")

try:
return getattr(module, model_name)(**user_param)
except FileNotFoundError as e:
logger = logging.getLogger()
logger.warning(f"Unable to find model {model_name} in {download_dir}")
if not download:
raise e
logger.warning(f"Attempting to download model...")
get_models(model_name, download_dir)
return getattr(module, model_name)(**user_param)
99 changes: 95 additions & 4 deletions python/snewpy/models/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import itertools as it
import os
from abc import ABC, abstractmethod
from warnings import warn

import numpy as np
from astropy import units as u
from astropy.table import Table, join
from astropy.units import UnitTypeError, get_physical_type
from astropy.units.quantity import Quantity
from scipy.special import loggamma
from snewpy import _model_downloader
Expand Down Expand Up @@ -353,10 +355,9 @@ def __init__(self, filename, eos='LS220', metadata={}):
# Open the requested filename using the model downloader.
datafile = _model_downloader.get_model_data(self.__class__.__name__, _filename)

with datafile.open():
simtab = Table.read(datafile.path,
names=['TIME', _lname, _ename, _e2name],
format='ascii')
simtab = Table.read(datafile,
names=['TIME', _lname, _ename, _e2name],
format='ascii')
simtab['TIME'].unit = 's'
simtab[_lname].unit = '1e51 erg/s'
simtab[_aname] = (2*simtab[_ename]**2 - simtab[_e2name]) / (simtab[_e2name] - simtab[_ename]**2)
Expand All @@ -374,3 +375,93 @@ def __init__(self, filename, eos='LS220', metadata={}):
super().__init__(simtab, metadata)


class _RegistryModel(ABC):
"""Base class for supernova model classes that initialise from physics parameters."""

_param_validator = None

@classmethod
def get_param_combinations(cls):
"""Returns all valid combinations of parameters for a given SNEWPY register model.

Subclasses can provide a Callable `cls._param_validator` that takes a combination of parameters
as an argument and returns True if a particular combinations of parameters is valid.
If None is provided, all combinations are considered valid.

Returns
-------
valid_combinations: tuple[dict]
A tuple of all valid parameter combinations stored as Dictionaries
"""
for key, val in cls.param.items():
if not isinstance(val, (list, Quantity)):
cls.param[key] = [val]
elif isinstance(val, Quantity) and val.size == 1:
try:
# check if val.value is iterable, e.g. a list or a NumPy array
iter(val.value)
except:
cls.param[key] = [val.value] * val.unit
combos = tuple(dict(zip(cls.param, combo)) for combo in it.product(*cls.param.values()))
return tuple(c for c in filter(cls._param_validator, combos))

def check_valid_params(cls, **user_params):
"""Checks that the model-specific values, units, names and conbinations of requested parameters are valid.

Parameters
----------
user_params : varies
User-requested model parameters to be tested for validity.
NOTE: This must be provided as kwargs that match the keys of cls.param

Raises
------
ValueError
If invalid model parameters are provided based on units, allowed values, etc.
UnitTypeError
If invalid units are provided for a model parameter

See Also
--------
snewpy.models.ccsn
snewpy.models.presn
"""
# Check that the appropriate number of params are provided
if not all(key in user_params for key in cls.param.keys()):
raise ValueError(f"Missing parameter! Expected {cls.param.keys()} but was given {user_params.keys()}")

# Check parameter units and values
for (key, allowed_params), user_param in zip(cls.param.items(), user_params.values()):

# If both have units, check that the user param value is valid. If valid, continue. Else, error
if type(user_param) == Quantity and type(allowed_params) == Quantity:
if get_physical_type(user_param.unit) != get_physical_type(allowed_params.unit):
raise UnitTypeError(f"Incorrect units {user_param.unit} provided for parameter {key}, "
f"expected {allowed_params.unit}")

elif np.isin(user_param.to(allowed_params.unit).value, allowed_params.value):
continue
else:
raise ValueError(f"Invalid value '{user_param}' provided for parameter {key}, "
f"allowed value(s): {allowed_params}")

# If one only one has units, then error
elif (type(user_param) == Quantity) ^ (type(allowed_params) == Quantity):
# User param has units, model param is unitless
if type(user_param) == Quantity:
raise ValueError(f"Invalid units {user_param.unit} for parameter {key} provided, expected None")
else:
raise ValueError(f"Missing units for parameter {key}, expected {allowed_params.unit}")

# Check that unitless user param value is valid. If valid, continue. Else, Error
elif user_param in allowed_params:
continue
else:
raise ValueError(f"Invalid value '{user_param}' provided for parameter {key}, "
f"allowed value(s): {allowed_params}")

# Check Combinations (Logic lives inside model subclasses under model.isvalid_param_combo)
if user_params not in cls.get_param_combinations():
raise ValueError(
f"Invalid parameter combination. See {cls.__class__.__name__}.get_param_combinations() for a "
"list of allowed parameter combinations.")
Loading