Added support for parameter filtering in numcosmo app. (#159)
* Added support for parameter filtering in numcosmo app.

* More fftlog calibration.

* Support for different typer behaviors.
vitenti committed May 19, 2024
1 parent db822d2 commit e9c949f
Showing 4 changed files with 105 additions and 56 deletions.
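The substance of the change is a pair of include/exclude filters for the catalog-analysis commands: column selections that were previously hard-coded as range(self.total_columns) now iterate over a filtered self.indices list, and LoadCatalog gains the corresponding typer options while the command dataclasses become keyword-only. A minimal sketch of the new option declarations, assuming the field names from the diff but a hypothetical class name and none of the surrounding command wiring:

```python
# Minimal sketch (not the NumCosmo source) of the include/exclude options
# added to LoadCatalog below; only the declaration pattern is shown.
import dataclasses
from typing import Annotated, Optional

import typer


@dataclasses.dataclass(kw_only=True)
class CatalogFilterOptions:
    """Keyword-only options selecting which catalog columns are analyzed."""

    include: Annotated[
        Optional[list[str]],
        typer.Option(help="Parameter or model names to include in the analysis."),
    ] = None
    exclude: Annotated[
        Optional[list[str]],
        typer.Option(help="Parameter or model names to exclude from the analysis."),
    ] = None
```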
43 changes: 15 additions & 28 deletions numcosmo_py/app/catalog.py
@@ -116,9 +116,7 @@ def __post_init__(self) -> None:
param_diag.add_column(
"Parameter", justify="left", style=desc_color, vertical="middle"
)
param_diag_matrix.append(
[mcat.col_full_name(i) for i in range(self.total_columns)]
)
param_diag_matrix.append([mcat.col_full_name(i) for i in self.indices])

# Values color
val_color = values_color
@@ -129,25 +127,21 @@ def __post_init__(self) -> None:
"Best-fit", justify="left", style=val_color, vertical="middle"
)
param_diag_matrix.append(
[f"{best_fit_vec.get(i): .6g}" for i in range(self.total_columns)]
[f"{best_fit_vec.get(i): .6g}" for i in self.indices]
)

# Parameter mean
param_diag.add_column(
"Mean", justify="left", style=val_color, vertical="middle"
)
param_diag_matrix.append(
[f"{fs.get_mean(i): .6g}" for i in range(self.total_columns)]
)
param_diag_matrix.append([f"{fs.get_mean(i): .6g}" for i in self.indices])

# Standard Deviation

param_diag.add_column(
"Standard Deviation", justify="left", style=val_color, vertical="middle"
)
param_diag_matrix.append(
[f"{fs.get_sd(i): .6g}" for i in range(self.total_columns)]
)
param_diag_matrix.append([f"{fs.get_sd(i): .6g}" for i in self.indices])

if self.nitems >= 10:
# Mean Standard Deviation
@@ -161,7 +155,7 @@ def __post_init__(self) -> None:

mean_sd_array = [
np.sqrt(fs.get_var(i) * tau_vec.get(i) / fs.nitens())
for i in range(self.total_columns)
for i in self.indices
]
param_diag_matrix.append([f"{mean_sd: .6g}" for mean_sd in mean_sd_array])

@@ -180,17 +174,13 @@ def __post_init__(self) -> None:
param_diag.add_column(
"tau", justify="left", style=val_color, vertical="middle"
)
param_diag_matrix.append(
[f"{tau_vec.get(i): .6g}" for i in range(self.total_columns)]
)
param_diag_matrix.append([f"{tau_vec.get(i): .6g}" for i in self.indices])

if self.nchains > 1:
# Gelman Rubin
gelman_rubin_row = []
gelman_rubin_row.append("Gelman-Rubin (G&B) Shrink Factor (R-1)")
skf = [
mcat.get_param_shrink_factor(i) - 1 for i in range(self.total_columns)
]
skf = [mcat.get_param_shrink_factor(i) - 1 for i in self.indices]
gelman_rubin_row.append("NA")
gr_worst = int(np.argmin(skf))
gelman_rubin_row.append(
@@ -207,7 +197,7 @@ def __post_init__(self) -> None:

# Constant Break

cb = [self.stats.estimate_const_break(i) for i in range(self.total_columns)]
cb = [self.stats.estimate_const_break(i) for i in self.indices]
cb_worst = int(np.argmax(cb))
const_break_row = []
const_break_row.append("Constant Break (CB) (iterations, points)")
@@ -246,7 +236,7 @@ def __post_init__(self) -> None:
param_diag_matrix.append(
[
f"{ess_vec.get(i):.0f} {ess_vec.get(i) * self.nchains:.0f}"
for i in range(self.total_columns)
for i in self.indices
]
)

@@ -282,10 +272,7 @@ def __post_init__(self) -> None:
style=val_color,
)
param_diag_matrix.append(
[
f"{(1.0 - hw_vec.get(i)) * 100.0:.1f}"
for i in range(self.total_columns)
]
[f"{(1.0 - hw_vec.get(i)) * 100.0:.1f}" for i in self.indices]
)

for row in np.array(param_diag_matrix).T:
@@ -297,14 +284,14 @@ def __post_init__(self) -> None:

covariance_matrix = Table(title="Covariance Matrix", expand=False)
covariance_matrix.add_column("Parameter", justify="right", style="bold")
for i in range(self.total_columns):
for i in self.indices:
covariance_matrix.add_column(
mcat.col_name(i).split(":")[-1], justify="right"
)

for i in range(self.total_columns):
for i in self.indices:
row = [mcat.col_name(i).split(":")[-1]]
for j in range(self.total_columns):
for j in self.indices:
cov_ij = fs.get_cor(i, j)
cor_ij_string = f"{cov_ij * 100.0: 3.0f}%"
styles_array = [
@@ -611,7 +598,7 @@ class PlotCorner(LoadCatalog):

plot_name: Annotated[
Optional[str],
typer.Argument(
typer.Option(
help="Name of the plot file.",
),
] = None
@@ -646,7 +633,7 @@ def __post_init__(self) -> None:
mcat = self.mcat
if self.plot_name is None:
self.plot_name = str(self.mcmc_file)
mcsample, _, _ = mcat_to_mcsamples(mcat, self.plot_name)
mcsample, _, _ = mcat_to_mcsamples(mcat, self.plot_name, indices=self.indices)

set_rc_params_article(ncol=2)
g = getdist.plots.get_subplot_plotter(
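In PlotCorner above, plot_name switches from typer.Argument to typer.Option, i.e. from a required positional argument to a named flag (--plot-name under typer's default naming). A self-contained sketch of that distinction using a plain typer command; the command and parameter names are illustrative, not the actual numcosmo CLI:

```python
# Illustrative typer command contrasting a positional Argument with a named
# Option; this is a toy script, not part of the numcosmo app.
from typing import Annotated, Optional

import typer

app = typer.Typer()


@app.command()
def plot_corner(
    mcmc_file: Annotated[str, typer.Argument(help="Catalog file (positional).")],
    plot_name: Annotated[
        Optional[str], typer.Option(help="Name of the plot file (--plot-name).")
    ] = None,
) -> None:
    """Toy command: the argument is required and positional, the option is named."""
    typer.echo(f"catalog={mcmc_file!r} plot_name={plot_name!r}")


if __name__ == "__main__":
    app()
```

Saved as a script, this would accept e.g. `chain.fits --plot-name corner.pdf`: the catalog path is positional, while the plot name must be passed as a flag or omitted.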
72 changes: 63 additions & 9 deletions numcosmo_py/app/loading.py
@@ -24,6 +24,7 @@

"""NumCosmo APP dataclasses and subcommands to load data.
This module contains dataclasses and subcommands to load data from files.
"""

import dataclasses
@@ -36,10 +37,13 @@
from numcosmo_py.sampling import set_ncm_console


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class LoadExperiment:
"""Common block for commands that load an experiment. All commands that load an
experiment should inherit from this class."""
"""Load an experiment file.
Common block for commands that load an experiment. All commands that load an
experiment should inherit from this class.
"""

experiment: Annotated[
Path, typer.Argument(help="Path to the experiment file to fit.")
@@ -52,8 +56,9 @@ class LoadExperiment:
help=(
"If given, the product file is written, the file name is the same as "
"the experiment file with the extension .product.yaml. "
"This option is incompatible with the output and starting-point options "
"since the product file contains the output and starting point."
"This option is incompatible with the output and starting-point "
"options since the product file contains the output and starting "
"point."
),
),
] = False
@@ -79,6 +84,7 @@ class LoadExperiment:
] = None

def __post_init__(self) -> None:
"""Initialize the experiment and load the data."""
ser = Ncm.Serialize.new(Ncm.SerializeOpt.CLEAN_DUP)

builders_file = self.experiment.with_suffix(".builders.yaml")
@@ -167,9 +173,10 @@ def __post_init__(self) -> None:
self.mset = mset

def _load_saved_mset(self) -> Optional[Ncm.MSet]:
"""Loads the saved model set from the starting point file "
"or the product file."""
"""Load the saved model.
Load the saved model-set from the starting point file or the product file.
"""
if self.starting_point is not None:
if not self.starting_point.exists():
raise RuntimeError(
@@ -199,15 +206,15 @@ def _load_saved_mset(self) -> Optional[Ncm.MSet]:
return None

def end_experiment(self):
"""Ends the experiment and writes the output file."""
"""End the experiment and writes the output file."""
if self.output is not None:
ser = Ncm.Serialize.new(Ncm.SerializeOpt.CLEAN_DUP)
ser.dict_str_to_yaml_file(
self.output_dict, self.output.absolute().as_posix()
)


@dataclasses.dataclass
@dataclasses.dataclass(kw_only=True)
class LoadCatalog(LoadExperiment):
"""Analyzes the results of a MCMC run."""

@@ -226,7 +233,22 @@ class LoadCatalog(LoadExperiment):
),
] = 0

include: Annotated[
Optional[list[str]],
typer.Option(
help="List of parameters and or model names to include in the analysis.",
),
] = None

exclude: Annotated[
Optional[list[str]],
typer.Option(
help="List of parameters and or model names to exclude from the analysis.",
),
] = None

def __post_init__(self) -> None:
"""Initialize the MCMC file and load the data."""
super().__post_init__()
if self.mcmc_file is None:
raise RuntimeError("No MCMC file given.")
@@ -248,6 +270,8 @@ def __post_init__(self) -> None:
self.total_columns: int = self.fparams_len + self.nadd_vals
self.nchains: int = self.mcat.nchains()

self._extract_indices()

self.full_stats: Ncm.StatsVec = self.mcat.peek_pstats()
assert isinstance(self.full_stats, Ncm.StatsVec)

@@ -259,3 +283,33 @@ def __post_init__(self) -> None:
assert isinstance(self.stats, Ncm.StatsVec)

self.nitems: int = self.stats.nitens()

def _extract_indices(self):
"""Extract the indices to include in the analysis."""
if self.include is None:
self.include = []
if self.exclude is None:
self.exclude = []
assert self.include is not None
assert self.exclude is not None
if not self.include and not self.exclude:
self.indices = list(range(self.total_columns))
else:
self.indices = []
if self.include and self.exclude:
for i in range(self.total_columns):
name = self.mcat.col_full_name(i)
if any(s in name for s in self.include) and not any(
s in name for s in self.exclude
):
self.indices.append(i)
elif self.include:
for i in range(self.total_columns):
name = self.mcat.col_full_name(i)
if any(s in name for s in self.include):
self.indices.append(i)
else:
for i in range(self.total_columns):
name = self.mcat.col_full_name(i)
if not any(s in name for s in self.exclude):
self.indices.append(i)
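_extract_indices selects catalog columns by substring match against mcat.col_full_name(i): with both lists given, a column must match an include pattern and no exclude pattern; with only one list given, that list alone decides; with neither, every column is kept. A standalone sketch of the same selection rule, using plain Python and hypothetical column names:

```python
# Standalone sketch mirroring the include/exclude rule of _extract_indices;
# the column names below are made up for illustration.
def select_indices(
    names: list[str],
    include: list[str] | None = None,
    exclude: list[str] | None = None,
) -> list[int]:
    include = include or []
    exclude = exclude or []
    if not include and not exclude:
        return list(range(len(names)))
    selected: list[int] = []
    for i, name in enumerate(names):
        included = any(s in name for s in include) if include else True
        excluded = any(s in name for s in exclude)
        if included and not excluded:
            selected.append(i)
    return selected


names = ["NcHICosmo:Omegac", "NcHICosmo:w", "NcHIPrim:ln10e10ASA"]
print(select_indices(names, include=["NcHICosmo"]))                 # [0, 1]
print(select_indices(names, include=["NcHICosmo"], exclude=["w"]))  # [0]
print(select_indices(names, exclude=["NcHIPrim"]))                  # [0, 1]
```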
42 changes: 25 additions & 17 deletions numcosmo_py/plotting/getdist.py
@@ -23,11 +23,12 @@

"""NumCosmoPy getdist utilities."""

from typing import List, Tuple
from typing import Optional
import warnings

import re
import numpy as np
import numpy.typing as npt
from getdist import MCSamples

from numcosmo_py import Ncm
@@ -36,13 +37,13 @@
def mcat_to_mcsamples(
mcat: Ncm.MSetCatalog,
name: str,
asinh_transform: Tuple[int, ...] = (),
asinh_transform: tuple[int, ...] = (),
burnin: int = 0,
thin: int = 1,
collapse: bool = False,
) -> Tuple[MCSamples, np.ndarray, np.ndarray]:
"""Converts a Ncm.MSetCatalog to a getdist.MCSamples object."""

indices: Optional[npt.NDArray[np.int64]] = None,
) -> tuple[MCSamples, np.ndarray, np.ndarray]:
"""Convert a Ncm.MSetCatalog to a getdist.MCSamples object."""
nchains: int = mcat.nchains()
max_time: int = mcat.max_time()

@@ -65,30 +66,37 @@ def mcat_to_mcsamples(
]
)

params: List[str] = [mcat.col_symb(i) for i in range(mcat.ncols())]
if indices is not None:
indices_array = np.array(indices)
else:
indices_array = np.arange(mcat.ncols())

# Get the -2 log likelihood column
m2lnL: int = mcat.get_m2lnp_var() # pylint:disable=invalid-name
posterior: np.ndarray = 0.5 * rows[:, m2lnL]
indices_array = indices_array[indices_array != m2lnL]

rows = np.delete(rows, m2lnL, 1)
params = list(np.delete(params, m2lnL, 0))
names = [re.sub("[^A-Za-z0-9_]", "", param) for param in params]

# Get the weights column
weights = None
if mcat.weighted():
# Original index is nadd_vals - 1,
# but since we removed m2lnL it is now nadd_vals - 2
weight_index = mcat.nadd_vals() - 2
assert weight_index >= 0
weights = rows[:, weight_index]
rows = np.delete(rows, weight_index, 1)
params = list(np.delete(params, weight_index, 0))
names = list(np.delete(names, weight_index, 0))
indices_array = indices_array[indices_array != weight_index]

rows = rows[:, indices_array]
param_symbols: list[str] = list(mcat.col_symb(i) for i in indices_array)
param_names: list[str] = [
re.sub("[^A-Za-z0-9_]", "", param) for param in param_symbols
]

if len(asinh_transform) > 0:
rows[:, asinh_transform] = np.arcsinh(rows[:, asinh_transform])
for i in asinh_transform:
params[i] = f"\\mathrm{{sinh}}^{{-1}}({params[i]})"
names[i] = f"asinh_{names[i]}"
param_symbols[i] = f"\\mathrm{{sinh}}^{{-1}}({param_symbols[i]})"
param_names[i] = f"asinh_{param_names[i]}"

if not collapse:
split_chains = np.array([rows[(burnin + n) :: nchains] for n in range(nchains)])
@@ -111,8 +119,8 @@ def mcat_to_mcsamples(
mcsample = MCSamples(
samples=split_chains,
loglikes=split_posterior,
names=names,
labels=params,
names=param_names,
labels=param_symbols,
label=name,
weights=split_weights,
)
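mcat_to_mcsamples now accepts an optional indices array and keeps only the requested catalog columns, after removing the -2 ln L column (and, for weighted catalogs, the weight column) from the index list. A toy numpy-only illustration of that selection step, with made-up data rather than the NumCosmo catalog API:

```python
# Toy numpy illustration of index-based column selection, as in the updated
# mcat_to_mcsamples; the data and column layout here are invented.
import numpy as np

rows = np.arange(20.0).reshape(4, 5)     # 4 samples, 5 catalog columns
m2lnL_col = 0                            # column holding -2 ln L
indices = np.array([0, 2, 4])            # columns requested by the caller

posterior = 0.5 * rows[:, m2lnL_col]     # half of -2 ln L, as in the code above
indices = indices[indices != m2lnL_col]  # -2 ln L is never a sampled parameter

samples = rows[:, indices]               # keep only the requested columns
print(samples.shape)                     # (4, 2)
```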
4 changes: 2 additions & 2 deletions tests/test_ncm_fftlog.c
@@ -325,7 +325,7 @@ test_ncm_fftlog_gausswin2_new (TestNcmFftlog *test, gconstpointer pdata)
void
test_ncm_fftlog_sbessel_j_new (TestNcmFftlog *test, gconstpointer pdata)
{
const guint N = g_test_rand_int_range (5800, 6000);
const guint N = g_test_rand_int_range (7800, 8000);
const guint ell = g_test_rand_int_range (0, 5);
NcmFftlog *fftlog = NCM_FFTLOG (ncm_fftlog_sbessel_j_new (ell, 0.0, 0.0, 20.0, N));
TestNcmFftlogK *argK = g_new (TestNcmFftlogK, 1);
@@ -365,7 +365,7 @@ test_ncm_fftlog_sbessel_j_q0_5_new (TestNcmFftlog *test, gconstpointer pdata)
void
test_ncm_fftlog_sbessel_j_q0_5_new (TestNcmFftlog *test, gconstpointer pdata)
{
const guint N = g_test_rand_int_range (5800, 6000);
const guint N = g_test_rand_int_range (7800, 8000);
const guint ell = g_test_rand_int_range (0, 5);
NcmFftlog *fftlog = NCM_FFTLOG (ncm_fftlog_sbessel_j_new (ell, 0.0, 0.0, 20.0, N));
TestNcmFftlogK *argK = g_new (TestNcmFftlogK, 1);
