Skip to content

Commit

Permalink
Merge branch 'devel'. Bring 'main' to version 1.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobmoss committed Mar 2, 2023
2 parents 52c4059 + 1bffc1d commit bf8acc9
Show file tree
Hide file tree
Showing 171 changed files with 1,245 additions and 566 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -10,7 +10,7 @@ repos:
- id: check-ast
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.1.0
hooks:
- id: black
exclude: venv*/|dustmaps/|grids/
2 changes: 1 addition & 1 deletion basta/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.2.1"
256 changes: 41 additions & 215 deletions basta/bastamain.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from tqdm import tqdm

from basta import freq_fit, stats, process_output, priors, distances, plot_seismic
from basta import freq_fit, stats, process_output, priors, distances, plot_driver
from basta import utils_seismic as su
from basta import utils_general as util
from basta._version import __version__
Expand Down Expand Up @@ -264,6 +264,15 @@ def BASTA(
obsintervals,
) = su.prepare_obs(inputparams, verbose=verbose, debug=debug)

# Apply prior on dnufit to mimick the range defined by dnufrac
if fitfreqs["dnuprior"] and ("dnufit" not in limits):
dnufit_frac = fitfreqs["dnufrac"] * fitfreqs["dnufit"]
dnuerr = max(3 * fitfreqs["dnufit_err"], dnufit_frac)
limits["dnufit"] = [
fitfreqs["dnufit"] - dnuerr,
fitfreqs["dnufit"] + dnuerr,
]

# Check if grid is interpolated
try:
Grid["header/interpolation_time"][()]
Expand Down Expand Up @@ -348,6 +357,7 @@ def BASTA(

# Translate True/False to Yes/No
strmap = ("No", "Yes")
print(" - Automatic prior on dnu: {0}".format(strmap[fitfreqs["dnuprior"]]))
print(
" - Constraining lowest l = 0 (n = {0}) with f = {1:.3f} +/-".format(
obskey[1, 0], obs[0, 0]
Expand Down Expand Up @@ -383,7 +393,14 @@ def BASTA(
strmap[fitfreqs["threepoint"]]
)
)
print(" - Value of dnu: {0:.3f} microHz".format(fitfreqs["dnufit"]))
if fitfreqs["dnufit_err"]:
print(
" - Value of dnu: {0:.3f} +/- {1:.3f} microHz".format(
fitfreqs["dnufit"], fitfreqs["dnufit_err"]
)
)
else:
print(" - Value of dnu: {0:.3f} microHz".format(fitfreqs["dnufit"]))
print(" - Value of numax: {0:.3f} microHz".format(fitfreqs["numax"]))

weightcomment = ""
Expand Down Expand Up @@ -503,7 +520,7 @@ def BASTA(
trackcounter += len(group.items())

# Prepare the main loop
shapewarn = False
shapewarn = 0
warn = True
selectedmodels = {}
noofind = 0
Expand All @@ -527,7 +544,7 @@ def BASTA(

# For grid with interpolated tracks, skip tracks flagged as empty
if grid_is_intpol:
if libitem["FeHini_weight"][()] < 0:
if libitem["IntStatus"][()] < 0:
continue

# Check for diffusion
Expand Down Expand Up @@ -576,7 +593,6 @@ def BASTA(

# Check which models have phases as specified
if "phase" in inputparams:

# Mapping of verbose input phases to internal numbers
pmap = {
"pre-ms": 1,
Expand Down Expand Up @@ -754,7 +770,7 @@ def BASTA(
)

# Raise possible warnings
if shapewarn:
if shapewarn == 1:
print(
"Warning: Found models with fewer frequencies than observed!",
"These were set to zero likelihood!",
Expand All @@ -764,6 +780,11 @@ def BASTA(
"This is probably due to the interpolation scheme. Lookup",
"`interpolate_frequencies` for more details.",
)
if shapewarn == 2:
print(
"Warning: Models without frequencies overlapping with observed",
"ignored due to interpolation of ratios being impossible.",
)
if noofposind == 0:
fio.no_models(starid, inputparams, "No models found")
return
Expand All @@ -777,9 +798,9 @@ def BASTA(

# Find and print highest likelihood model info
maxPDF_path, maxPDF_ind = stats.get_highest_likelihood(
Grid, selectedmodels, outparams
Grid, selectedmodels, inputparams
)
stats.get_lowest_chi2(Grid, selectedmodels, outparams)
stats.get_lowest_chi2(Grid, selectedmodels, inputparams)

# Generate posteriors of ascii- and plotparams
# --> Print posteriors to console and log
Expand All @@ -798,219 +819,24 @@ def BASTA(
experimental=experimental,
validationmode=validationmode,
)

# Make frequency-related plots
freqplots = inputparams.get("freqplots")
if fitfreqs["active"] and len(freqplots):
# Check which plots to create
allfplots = freqplots[0] == True
if any(x == "allechelle" for x in freqplots):
freqplots += ["dupechelle", "echelle", "pairechelle"]
if any(x in freqtypes.rtypes for x in freqplots):
freqplots += ["ratios"]

# Naming of plots preparation
plotfmt = inputparams["plotfmt"]
plotfname = outfilename + "_{0}." + plotfmt

rawmaxmod = Grid[maxPDF_path + "/osc"][maxPDF_ind]
rawmaxmodkey = Grid[maxPDF_path + "/osckey"][maxPDF_ind]
maxmod = su.transform_obj_array(rawmaxmod)
maxmodkey = su.transform_obj_array(rawmaxmodkey)
maxmod = maxmod[:, maxmodkey[0, :] < 2.5]
maxmodkey = maxmodkey[:, maxmodkey[0, :] < 2.5]
maxjoins = freq_fit.calc_join(
mod=maxmod,
modkey=maxmodkey,
obs=obs,
plot_driver.plot_all_seismic(
freqplots,
Grid=Grid,
fitfreqs=fitfreqs,
obsfreqmeta=obsfreqmeta,
obsfreqdata=obsfreqdata,
obskey=obskey,
obs=obs,
obsintervals=obsintervals,
selectedmodels=selectedmodels,
path=maxPDF_path,
ind=maxPDF_ind,
plotfname=outfilename + "_{0}." + inputparams["plotfmt"],
debug=debug,
)
maxjoinkeys, maxjoin = maxjoins
maxmoddnu = Grid[maxPDF_path + "/dnufit"][maxPDF_ind]

if allfplots or "echelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=maxjoin,
joinkeys=maxjoinkeys,
pair=False,
duplicate=False,
output=plotfname.format("echelle_uncorrected"),
)
if allfplots or "pairechelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=maxjoin,
joinkeys=maxjoinkeys,
pair=True,
duplicate=False,
output=plotfname.format("pairechelle_uncorrected"),
)
if allfplots or "dupechelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=maxjoin,
joinkeys=maxjoinkeys,
duplicate=True,
pair=True,
output=plotfname.format("dupechelle_uncorrected"),
)

if fitfreqs["fcor"] == "None":
corjoin = maxjoin
coeffs = [1]
elif fitfreqs["fcor"] == "HK08":
corjoin, coeffs = freq_fit.HK08(
joinkeys=maxjoinkeys,
join=maxjoin,
nuref=fitfreqs["numax"],
bcor=fitfreqs["bexp"],
)
elif fitfreqs["fcor"] == "BG14":
corjoin, coeffs = freq_fit.BG14(
joinkeys=maxjoinkeys, join=maxjoin, scalnu=fitfreqs["numax"]
)
elif fitfreqs["fcor"] == "cubicBG14":
corjoin, coeffs = freq_fit.cubicBG14(
joinkeys=maxjoinkeys, join=maxjoin, scalnu=fitfreqs["numax"]
)

if len(coeffs) > 1:
print("The surface correction coefficients are", *coeffs)
else:
print("The surface correction coefficient is", *coeffs)

if allfplots or "echelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=corjoin,
joinkeys=maxjoinkeys,
freqcor=fitfreqs["fcor"],
coeffs=coeffs,
scalnu=fitfreqs["numax"],
pair=False,
duplicate=False,
output=plotfname.format("echelle"),
)
if allfplots or "pairechelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=corjoin,
joinkeys=maxjoinkeys,
freqcor=fitfreqs["fcor"],
coeffs=coeffs,
scalnu=fitfreqs["numax"],
pair=True,
duplicate=False,
output=plotfname.format("pairechelle"),
)
if allfplots or "dupechelle" in freqplots:
plot_seismic.echelle(
selectedmodels,
Grid,
obs,
obskey,
mod=maxmod,
modkey=maxmodkey,
dnu=fitfreqs["dnufit"],
join=corjoin,
joinkeys=maxjoinkeys,
freqcor=fitfreqs["fcor"],
coeffs=coeffs,
scalnu=fitfreqs["numax"],
duplicate=True,
pair=True,
output=plotfname.format("dupechelle"),
)
if "freqcormap" in freqplots or debug:
plot_seismic.correlation_map(
"freqs",
obsfreqdata,
plotfname.format("freqs_cormap"),
obskey=obskey,
)
if obsfreqmeta["getratios"]:
for ratseq in obsfreqmeta["ratios"]["plot"]:
ratnamestr = "ratios_{0}".format(ratseq)
plot_seismic.ratioplot(
obsfreqdata,
maxjoinkeys,
maxjoin,
maxmodkey,
maxmod,
ratseq,
output=plotfname.format(ratnamestr),
threepoint=fitfreqs["threepoint"],
interp_ratios=fitfreqs["interp_ratios"],
)
if fitfreqs["correlations"]:
plot_seismic.correlation_map(
ratseq,
obsfreqdata,
output=plotfname.format(ratnamestr + "_cormap"),
)

if obsfreqmeta["getepsdiff"]:
for epsseq in obsfreqmeta["epsdiff"]["plot"]:
epsnamestr = "epsdiff_{0}".format(epsseq)
plot_seismic.epsilon_difference_diagram(
mod=maxmod,
modkey=maxmodkey,
moddnu=maxmoddnu,
sequence=epsseq,
obsfreqdata=obsfreqdata,
output=plotfname.format(epsnamestr),
)
if fitfreqs["correlations"]:
plot_seismic.correlation_map(
epsseq,
obsfreqdata,
output=plotfname.format(epsnamestr + "_cormap"),
)
if obsfreqmeta["getepsdiff"] and debug:
if len(obsfreqmeta["epsdiff"]["plot"]) > 0:
plot_seismic.epsilon_difference_components_diagram(
mod=maxmod,
modkey=maxmodkey,
moddnu=maxmoddnu,
obs=obs,
obskey=obskey,
dnudata=obsfreqdata["freqs"]["dnudata"],
obsfreqdata=obsfreqdata,
obsfreqmeta=obsfreqmeta,
output=plotfname.format("DEBUG_epsdiff_components"),
)
else:
print(
"Did not get any frequency file input, skipping ratios and echelle plots."
Expand Down
22 changes: 21 additions & 1 deletion basta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,27 @@ class freqtypes:
surfeffcorrs = ["HK08", "BG14", "cubicBG14"]


@dataclass
class statdata:
"""
Constant values for statistics, to ensure consistensy across code
Contains
--------
quantiles : list
Median, lower and upper percentiles of Bayesian posterior
distributions to draw
nsamples : int
Number of samples to draw when sampling
nsigma : float
Fractional standard deviation used for smoothing
"""

quantiles = [0.5, 0.158655, 0.841345]
nsamples = 100000
nsigma = 0.25


@dataclass
class parameters:
"""
Expand Down Expand Up @@ -254,7 +275,6 @@ def exclude_params(excludeparams):
return parnames

def get_keys(inputparams):

"""
Takes a list of input parameters (or a
single parameter) as strings and returns
Expand Down
Loading

0 comments on commit bf8acc9

Please sign in to comment.