In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sys.path.insert(0, '../')
from snat_sim.utils import *


In [None]:
def get_combined_data(directory, key):
    """Return data from a directory of pipeline HDF5 files

    Reads the table stored under ``key`` from every ``*.h5`` file in
    ``directory`` and concatenates the results into a single table.

    Args:
        directory (Path | str): The directory to parse data files from
        key (str): Key of the table in the HDF5 files

    Returns:
        A pandas DataFrame with pipeline data, indexed by the ``snid`` column

    Raises:
        ValueError: If no ``*.h5`` files are found in ``directory``
    """

    # Accept plain strings as well as Path objects, and sort the matches so the
    # row order of the combined table does not depend on filesystem listing order
    file_paths = sorted(Path(directory).glob('*.h5'))
    if not file_paths:
        raise ValueError('No h5 files found in given directory')

    dataframes = []
    for file_path in file_paths:
        with pd.HDFStore(file_path, 'r') as datastore:
            dataframes.append(datastore.get(key))

    return pd.concat(dataframes).set_index('snid')
    

In [None]:
def build_data_iter(axes, df, limit=None):
    """Pair each axis with its position and a column of ``df``

    Args:
        axes: Sequence of axes (e.g. one row of a subplot grid) matched
            positionally with the columns of ``df``
        df (DataFrame): Table whose columns are paired with ``axes`` in order
        limit (int): If given, only yield entries whose position is ``<= limit``

    Returns:
        An iterator of ``(position, axis, column label, column data)`` tuples.
        Iteration stops at the shortest of ``axes``, the columns of ``df``,
        and the position range.
    """

    positions = range(len(axes)) if limit is None else range(limit + 1)

    # ``DataFrame.items`` replaces ``iteritems``, which was deprecated in
    # pandas 1.5 and removed in pandas 2.0
    return (
        (position, axis, label, data)
        for position, axis, (label, data) in zip(positions, axes, df.items())
    )
    

**TODO:** Make the linear-fit line, contours, and labels configurable through keyword arguments.

In [None]:
def corner_plot_fit_results(data, x_vals, y_vals, lin_fit_style=None, alpha=.25, figsize=(20, 20)):
    """Scatter one set of columns against another as a lower-triangle grid

    Row ``i`` of the grid shows column ``y_vals[i]`` on the vertical axis and
    column ``j`` shows ``x_vals[j]`` on the horizontal axis. Only panels with
    ``j <= i`` are drawn. Diagonal panels (``i == j``) additionally show a
    first-order polynomial fit of ``y_vals[i]`` against ``x_vals[i]``, with
    the slope and intercept in the legend.

    Args:
        data (DataFrame): Table containing every column in ``x_vals`` and ``y_vals``
        x_vals (list[str]): Columns plotted along the horizontal axes
        y_vals (list[str]): Columns plotted along the vertical axes
        lin_fit_style (dict): Keyword arguments passed to ``Axes.plot`` for the
            fit line (default: ``dict(color='k', linestyle='--')``)
        alpha (float): Opacity of the scatter points
        figsize (tuple): Figure size passed to ``plt.subplots``
    """

    # Previously this default was applied unconditionally, silently ignoring
    # any style passed by the caller
    if lin_fit_style is None:
        lin_fit_style = dict(color='k', linestyle='--')

    # squeeze=False keeps ``axes`` two dimensional even for 1x1 or 1xN grids,
    # so the ``axes[-1, :]`` / ``axes[:, 0]`` indexing below always works
    fig, axes = plt.subplots(len(y_vals), len(x_vals), figsize=figsize, squeeze=False)

    # Rows follow ``y_vals`` and columns follow ``x_vals`` so that the axis
    # labels assigned at the end describe the data drawn in each panel
    y_iter = build_data_iter(axes, data[y_vals])
    for i, fig_row, ylabel, ydata in y_iter:

        x_iter = build_data_iter(fig_row, data[x_vals], limit=i)
        for j, axis, xlabel, xdata in x_iter:
            axis.scatter(xdata, ydata, alpha=alpha)

            if i == j:
                # np.polyfit cannot handle NaN, so fit only complete pairs
                x_arr = np.asarray(xdata, dtype=float)
                y_arr = np.asarray(ydata, dtype=float)
                finite = np.isfinite(x_arr) & np.isfinite(y_arr)

                m, b = np.polyfit(x_arr[finite], y_arr[finite], 1)
                fit = np.poly1d((m, b))
                axis.plot(x_arr[finite], fit(x_arr[finite]), **lin_fit_style, label=f'm={m:.2f} b={b:.2f}')
                axis.legend()

    for axis, label in zip(axes[-1, :], x_vals):
        axis.set_xlabel(label)

    for axis, label in zip(axes[:, 0], y_vals):
        axis.set_ylabel(label)
    

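**TODO:** Determine why `dropna` is needed before plotting. A likely cause: `DataFrame.join` performs a left join by default, so any simulated target without a matching fit result ends up with NaN fit parameters, and `np.polyfit` cannot handle NaN values.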

In [None]:
# Pipeline output files are read from the user's Downloads folder
base_dir = Path.home().joinpath('Downloads')

sim_params = get_combined_data(directory=base_dir, key='/simulation/params')
fit_results = get_combined_data(directory=base_dir, key='/fitting/params')

# Left join on the shared index: every simulated row is kept, with NaN in the
# fit columns wherever no matching fit result exists
pipeline_data = sim_params.join(fit_results, how='left')


In [None]:
# Drop only rows missing a value in one of the plotted columns; a NaN in an
# unrelated column should not remove an otherwise complete data point
sim_columns = ['z', 'x0', 'x1', 'c']
fit_columns = ['fit_z', 'fit_x0', 'fit_x1', 'fit_c']
corner_plot_fit_results(
    data=pipeline_data.dropna(subset=sim_columns + fit_columns),
    x_vals=sim_columns,
    y_vals=fit_columns,
)
