diff --git a/numcosmo_py/app/catalog.py b/numcosmo_py/app/catalog.py index 96f31640..b01bf8bc 100644 --- a/numcosmo_py/app/catalog.py +++ b/numcosmo_py/app/catalog.py @@ -26,6 +26,7 @@ import math import dataclasses from typing import Optional, Annotated, List +from pathlib import Path import typer from rich.table import Table from rich.text import Text @@ -627,6 +628,35 @@ class PlotCorner(LoadCatalog): ), ] = 0 + mcsample_only: Annotated[ + bool, + typer.Option( + help="Generate only the MCSample object.", + hidden=True, + ), + ] = False + + extra_experiment: Annotated[ + Optional[list[Path]], + typer.Option( + help="Run extra experiment.", + ), + ] = None + + extra_mcmc_file: Annotated[ + Optional[list[Path]], + typer.Option( + help="Extra MCMC files.", + ), + ] = None + + extra_burnin: Annotated[ + Optional[list[int]], + typer.Option( + help="Extra burn-ins.", + ), + ] = None + def __post_init__(self) -> None: """Corner plot of the catalog.""" super().__post_init__() @@ -635,7 +665,48 @@ def __post_init__(self) -> None: if self.plot_name is None: self.plot_name = str(self.mcmc_file) mcsample, _, _ = mcat_to_mcsamples(mcat, self.plot_name, indices=self.indices) + self.mcsample = mcsample + + if self.mcsample_only: + self.end_experiment() + return + + if self.extra_experiment is None: + self.extra_experiment = [] + if self.extra_mcmc_file is None: + self.extra_mcmc_file = [] + if self.extra_burnin is None: + self.extra_burnin = [] + if len(self.extra_experiment) != len(self.extra_mcmc_file): + raise ValueError( + "Extra experiments and MCMC files must have the same length." + ) + if len(self.extra_experiment) != len(self.extra_burnin): + raise ValueError( + "Extra experiments and burn-ins must have the same length." + ) + + mcsamples = [mcsample] + for extra_experiment, extra_mcmc_file, extra_burnin in zip( + self.extra_experiment, self.extra_mcmc_file, self.extra_burnin + ): + extra_exp = dataclasses.replace( + self, + experiment=extra_experiment, + mcmc_file=extra_mcmc_file, + burnin=extra_burnin, + log_file=None, + mcsample_only=True, + plot_name=None, + ) + mcsamples.append(extra_exp.mcsample) + + self.plot_mcsamples(mcsamples) + + def plot_mcsamples(self, mcsamples: list[getdist.MCSamples]): + """Plot MCSamples.""" + mcat = self.mcat set_rc_params_article(ncol=2) g = getdist.plots.get_subplot_plotter( width_inch=plt.rcParams["figure.figsize"][0] @@ -645,7 +716,7 @@ def __post_init__(self) -> None: if self.mark_bestfit: bf = np.array(mcat.get_bestfit_row().dup_array())[1:] g.triangle_plot( - [mcsample], shaded=True, markers=bf, title_limit=self.title_limit + mcsamples, shaded=True, markers=bf, title_limit=self.title_limit ) if self.output is not None: