Skip to content

Commit

Permalink
Merge branch 'master' into twofluids
Browse files Browse the repository at this point in the history
  • Loading branch information
vitenti committed May 19, 2024
2 parents c3bdac6 + 8c8f4ff commit 739dd4f
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion numcosmo_py/app/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import math
import dataclasses
from typing import Optional, Annotated, List
from pathlib import Path
import typer
from rich.table import Table
from rich.text import Text
Expand Down Expand Up @@ -627,6 +628,35 @@ class PlotCorner(LoadCatalog):
),
] = 0

mcsample_only: Annotated[
bool,
typer.Option(
help="Generate only the MCSample object.",
hidden=True,
),
] = False

extra_experiment: Annotated[
Optional[list[Path]],
typer.Option(
help="Run extra experiment.",
),
] = None

extra_mcmc_file: Annotated[
Optional[list[Path]],
typer.Option(
help="Extra MCMC files.",
),
] = None

extra_burnin: Annotated[
Optional[list[int]],
typer.Option(
help="Extra burn-ins.",
),
] = None

def __post_init__(self) -> None:
"""Corner plot of the catalog."""
super().__post_init__()
Expand All @@ -635,7 +665,48 @@ def __post_init__(self) -> None:
if self.plot_name is None:
self.plot_name = str(self.mcmc_file)
mcsample, _, _ = mcat_to_mcsamples(mcat, self.plot_name, indices=self.indices)
self.mcsample = mcsample

if self.mcsample_only:
self.end_experiment()
return

if self.extra_experiment is None:
self.extra_experiment = []
if self.extra_mcmc_file is None:
self.extra_mcmc_file = []
if self.extra_burnin is None:
self.extra_burnin = []

if len(self.extra_experiment) != len(self.extra_mcmc_file):
raise ValueError(
"Extra experiments and MCMC files must have the same length."
)
if len(self.extra_experiment) != len(self.extra_burnin):
raise ValueError(
"Extra experiments and burn-ins must have the same length."
)

mcsamples = [mcsample]
for extra_experiment, extra_mcmc_file, extra_burnin in zip(
self.extra_experiment, self.extra_mcmc_file, self.extra_burnin
):
extra_exp = dataclasses.replace(
self,
experiment=extra_experiment,
mcmc_file=extra_mcmc_file,
burnin=extra_burnin,
log_file=None,
mcsample_only=True,
plot_name=None,
)
mcsamples.append(extra_exp.mcsample)

self.plot_mcsamples(mcsamples)

def plot_mcsamples(self, mcsamples: list[getdist.MCSamples]):
"""Plot MCSamples."""
mcat = self.mcat
set_rc_params_article(ncol=2)
g = getdist.plots.get_subplot_plotter(
width_inch=plt.rcParams["figure.figsize"][0]
Expand All @@ -645,7 +716,7 @@ def __post_init__(self) -> None:
if self.mark_bestfit:
bf = np.array(mcat.get_bestfit_row().dup_array())[1:]
g.triangle_plot(
[mcsample], shaded=True, markers=bf, title_limit=self.title_limit
mcsamples, shaded=True, markers=bf, title_limit=self.title_limit
)

if self.output is not None:
Expand Down

0 comments on commit 739dd4f

Please sign in to comment.