In [2]:
import sys

# Make modules in the parent directory importable (presumably the local
# ``snat_sim`` package one level above this notebook — confirm). The guard keeps
# the cell idempotent: re-running it no longer prepends duplicate entries.
if '..' not in sys.path:
    sys.path.insert(0, '..')


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy import units as u
from astropy.coordinates import SkyCoord
from scipy.stats.stats import pearsonr
from egon.pipeline import Pipeline
from egon.mock import MockTarget
from typing import *
from snat_sim.types import *
from snat_sim.plasticc import PLAsTICC
from snat_sim.pipeline.nodes import *

from snat_sim.pipeline import FittingPipeline
from snat_sim.models import SNModel, ReferenceCatalog, StaticPWVTrans


In [3]:
from typing import *

from egon.pipeline import Pipeline
from egon.mock import MockTarget

from snat_sim.types import *
from snat_sim.plasticc import PLAsTICC
from snat_sim.pipeline.nodes import *
from snat_sim.models import SNModel, ReferenceCatalog, StaticPWVTrans




In [4]:
def validate(
    cadence: str,
    sim_model: SNModel,
    fit_model: SNModel,
    vparams: Sequence[str],
    out_path: Union[str, Path],
    fitting_pool: int = 1,
    simulation_pool: int = 1,
    writing_pool: int = 1,
    bounds: Optional[Dict[str, Tuple[float, float]]] = None,
    max_queue: int = 100,
    iter_lim: Union[int, float] = float('inf'),
    catalog = None,
    add_scatter: bool = True,
    fixed_snr: Optional[float] = None,
    overwrite: bool = False,
    write_lc_sims: bool = False
) -> Any:
    """Simulate and fit light-curves for a PLAsTICC cadence and collect the results

    Cadence data is loaded from PLAsTICC, light-curves are simulated with
    ``sim_model`` and then fitted with ``fit_model``. Failed simulations, successful
    fits, and failed fits are all routed to an in-memory ``MockTarget`` rather than
    written to ``out_path``; the data accumulated by that target is returned.

    Args:
        cadence: Name of the PLAsTICC cadence to use when simulating light-curves
        sim_model: Model to use when simulating light-curves
        fit_model: Model to use when fitting light-curves
        vparams: Parameter names to vary in the fit
        out_path: Path to write results to (currently unused)
        fitting_pool: Number of child processes allocated to fitting light-curves
        simulation_pool: Number of child processes allocated to simulating light-curves
        writing_pool: Presumably the number of processes for writing results (currently unused)
        bounds: Bounds to impose on ``fit_model`` parameters when fitting light-curves
        max_queue: Maximum number of light-curves to store in pipeline at once (currently unused)
        iter_lim: Limit number of processed light-curves (Useful for profiling)
        catalog: Reference star catalog to calibrate simulated supernova with
        add_scatter: Add randomly generated scatter to simulated light-curve points
        fixed_snr: Simulate light-curves with a fixed signal to noise ratio
        overwrite: Whether to allow overwriting an existing output file (currently unused)
        write_lc_sims: Whether to include simulated light-curves in the output (currently unused)

    Returns:
        The ``accumulated_data`` attribute of the ``MockTarget`` that received
        every pipeline output
    """

    # NOTE(review): ``out_path``, ``writing_pool``, ``max_queue``, ``overwrite`` and
    # ``write_lc_sims`` are accepted for interface compatibility but not used below.

    # Define the nodes of the analysis pipeline
    # NOTE(review): ``model=11`` is presumably the PLAsTICC SN Ia model id — confirm
    plasticc_cadence = PLAsTICC(cadence, model=11)
    load_plasticc = LoadPlasticcCadence(plasticc_cadence, iter_lim=iter_lim)
    result_collector = MockTarget()

    simulate_light_curves = SimulateLightCurves(
        sn_model=sim_model,
        catalog=catalog,
        num_processes=simulation_pool,
        add_scatter=add_scatter,
        fixed_snr=fixed_snr
    )

    fit_light_curves = FitLightCurves(
        sn_model=fit_model,
        vparams=vparams,
        bounds=bounds,
        num_processes=fitting_pool
    )

    # Connect pipeline nodes together: every terminal output ends at the collector
    load_plasticc.output.connect(simulate_light_curves.input)
    simulate_light_curves.success_output.connect(fit_light_curves.input)
    simulate_light_curves.failure_output.connect(result_collector.input)
    fit_light_curves.success_output.connect(result_collector.input)
    fit_light_curves.failure_output.connect(result_collector.input)

    # Execute the nodes in the same upstream-to-downstream order they are connected
    for node in (load_plasticc, simulate_light_curves, fit_light_curves, result_collector):
        node.execute()

    return result_collector.accumulated_data
    
    

In [5]:
# Static PWV atmospheric transmission effect with ``pwv`` set to 4.
# The same effect instance is attached to both supernova models below.
transmission_effect = StaticPWVTrans()
transmission_effect.set(pwv=4)

# Supernova models: one simulates the light-curves, the other fits them
simulation_model = SNModel('salt2-extended')
fitting_model = SNModel('salt2-extended')
for model in [simulation_model, fitting_model]:
    model.add_effect(transmission_effect, 'pwv', 'obs')

# Reference stellar catalog used to calibrate the simulated supernovae
stellar_catalog = ReferenceCatalog('G2', 'M5', 'K2')

# Simulate and fit light-curves for the ``alt_sched_rolling`` cadence.
# NOTE(review): ``validate`` does not currently write anything to ``out_path``.
data = validate(
    'alt_sched_rolling',
    sim_model=simulation_model,
    fit_model=fitting_model,
    vparams=('x0', 'x1', 'c'),
    out_path='temp.h5',
    catalog=stellar_catalog,
)


alt_sched_rolling: 0it [00:00, ?it/s]


In [None]:
def get_combined_data_table(directory, key):
    """Return a single data table combined from a directory of HDF5 files

    Reads the table stored under ``key`` from every ``*.h5`` file directly
    inside ``directory`` (non-recursive) and concatenates them.

    Args:
        directory (Path or str): The directory to parse data files from
        key              (str): Key of the table in the HDF5 files

    Returns:
        A pandas DataFrame with pipeline data, indexed by the ``snid`` column

    Raises:
        ValueError: If ``directory`` contains no ``*.h5`` files
    """

    # Accept plain strings as well as Path objects
    directory = Path(directory)

    # Sort so the row order of the combined table does not depend on the
    # arbitrary order in which the filesystem lists the files
    if not (h5_files := sorted(directory.glob('*.h5'))):
        raise ValueError(f'No h5 files found in {directory}')

    dataframes = []
    for file in h5_files:
        print(file)
        with pd.HDFStore(file, 'r') as datastore:
            dataframes.append(datastore.get(key))

    return pd.concat(dataframes).set_index('snid')


def load_pipeline_data(directory):
    """Return the combined input and output parameters from a pipeline run

    Args:
        directory (Path or str): The directory to parse data files from

    Returns:
        A pandas DataFrame indexed by ``snid`` containing the fit results
        left-joined with the simulation parameters. Columns present in both
        tables are suffixed with ``_fit`` and ``_sim`` respectively. Values of
        -99.99 are replaced with NaN.
    """

    sim_params = get_combined_data_table(directory, '/simulation/params')
    fit_results = get_combined_data_table(directory, '/fitting/params')

    # Combine the input simulation parameters and the fit results into a single dataframe.
    # The keys in ``fit_results`` are expected to be a proper subset of ``sim_params``,
    # so we left join on ``fit_results`` (``how='left'`` is the default for ``join``).
    # The suffixes disambiguate columns that exist in both tables; without them
    # ``DataFrame.join`` raises a ValueError on overlapping column names.
    pipeline_data = fit_results.join(sim_params, lsuffix='_fit', rsuffix='_sim')

    # -99.99 is treated as a missing-value sentinel (presumably written for
    # failed fits — confirm against the pipeline writer) and converted to NaN
    return pipeline_data.replace(-99.99, np.nan)


In [None]:
# Merge simulated and fitted parameters from every ``*.h5`` file in the working directory
data_dir = Path().resolve()
load_pipeline_data(data_dir)


In [None]:
# Open a temporary pipeline output file for inspection. The path is relative to the
# notebook's working directory instead of an absolute path on one machine
# (assumes the notebook is run from the ``notebooks`` directory — confirm).
# NOTE(review): the store is left open for the next cell; call ``d.close()`` when done.
d = pd.HDFStore('.temp_fn0.h5', 'r')

In [None]:
# Show the table keys stored in the open HDF5 file
list(d.keys())