In [None]:
import pathlib

import astropy.coordinates as coord
import astropy.table as at
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from scipy.spatial import cKDTree
from scipy.stats import binned_statistic
from scipy.interpolate import interp1d
from tqdm import tqdm

from gala.mpl_style import turbo

from totoro.data import datasets
from totoro.abundance_helpers import elem_to_label
from totoro.config import cache_path, plot_path

In [None]:
# Load the cached per-element-ratio optimization results for every dataset.
# Each file is named 'optimize-results-<elem>.csv'; the part after the last
# '-' (before the extension) is used as the element-ratio key.
# Datasets with fewer than MIN_N_ELEM_RATIOS result files are skipped.
MIN_N_ELEM_RATIOS = 5

all_tbls = {}
for data_name in datasets:
    this_cache_path = cache_path / data_name

    tbls = {}
    # sorted(): glob order is filesystem-dependent; sorting makes the element
    # order (and therefore legend order) reproducible between runs.
    for path in sorted(this_cache_path.glob('optimize-results-*.csv')):
        elem = path.name.split('.')[0].split('-')[-1]
        tbls[elem] = at.Table.read(path)

    if len(tbls) >= MIN_N_ELEM_RATIOS:
        all_tbls[data_name] = tbls
    print(data_name, len(tbls))

Assign each element ratio a unique color, shared across all datasets:

In [None]:
# Collect every element ratio that appears in any dataset and give each one a
# fixed color from the `turbo` colormap, so an element ratio has the same
# color in every figure.
all_elems = set().union(*(tbls.keys() for tbls in all_tbls.values()))

# Iterate in sorted order: iteration order of a set of strings follows their
# hashes, which are randomized per Python process, so an unsorted loop would
# reshuffle the colors every time the kernel is restarted.
elem_to_color = {
    elem: turbo(i / len(all_elems))
    for i, elem in enumerate(sorted(all_elems))
}

In [None]:
# Fiducial parameter values; the plots mark them with dashed gray lines.
# NOTE(review): units presumably match the axis labels used in the plots
# (z_sun in pc, v_z,sun in km/s) -- confirm against the fitting code.
fiducials = {
    'mdisk_f': 1.,
    'disk_hz': 0.28,
    'zsun': 20.8,
    'vzsun': 7.78,
}

# (x, y) column pairs, one per panel of the per-dataset scatter plots.
# The first panel shows z_sun (not the disk scale height). This matches its
# z_sun axis label and limits, and the (mdisk_f, zsun) joint projection drawn
# in the same panel of the error-ellipse figure.
colcols = [
    ('mdisk_f', 'zsun'),
    ('mdisk_f', 'vzsun'),
    ('zsun', 'vzsun'),
]

In [None]:
# For each dataset, scatter the per-element-ratio optimization results in the
# column pairs listed in `colcols` (one color per element ratio), with the
# fiducial values marked by dashed gray lines.

# Axis labels keyed by column name, so each panel's labels always describe the
# columns actually plotted in it. Columns without an entry are labeled by name.
axis_labels = {
    'mdisk_f': r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$',
    'zsun': r'$z_\odot$ [pc]',
    'vzsun': r'$v_{z,\odot}$ ' + f'[{u.km/u.s:latex_inline}]',
}

for data_name, tbls in all_tbls.items():
    fig, axes = plt.subplots(1, 3, figsize=(15, 5.5),
                             constrained_layout=True)

    for ax, (col1, col2) in zip(axes, colcols):
        for elem, tbl in tbls.items():
            ax.plot(tbl[col1], tbl[col2],
                    ls='none', marker='o', mew=0, ms=4,
                    label=elem_to_label(elem), color=elem_to_color[elem])

        ax.set_xlabel(axis_labels.get(col1, col1))
        ax.set_ylabel(axis_labels.get(col2, col2))

        ax.axvline(fiducials[col1], zorder=-10, color='#aaaaaa', linestyle='--')
        ax.axhline(fiducials[col2], zorder=-10, color='#aaaaaa', linestyle='--')

    axes[0].legend()

    fig.set_facecolor('w')
    fig.suptitle(data_name, fontsize=24)

In [None]:
# NOTE(review): stray inspection of the astropy `ICRS` frame class -- as the
# last expression of the cell it is only displayed. No visible cell uses the
# result, so this cell looks safe to delete.
coord.ICRS

### Error ellipses

In [None]:
# From https://matplotlib.org/devdocs/gallery/statistics/confidence_ellipse.html
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms


def confidence_ellipse(x, y, ax, n_std=1.0, facecolor='none', **kwargs):
    """Draw the ``n_std``-sigma covariance ellipse of the samples ``(x, y)``.

    Adapted from the Matplotlib "confidence ellipse" gallery example. The
    ellipse is centered on the sample mean and shaped by the sample covariance
    of ``x`` and ``y``; the drawing itself is delegated to
    ``plot_cov_ellipse`` so the two helpers cannot diverge.

    Parameters
    ----------
    x, y : array-like
        Paired samples; must have the same number of elements.
    ax : matplotlib.axes.Axes
        Axes the ellipse is added to.
    n_std : float, optional
        Number of standard deviations spanned by the ellipse.
    facecolor : color, optional
        Fill color of the ellipse.
    **kwargs
        Forwarded to ``matplotlib.patches.Ellipse``.

    Returns
    -------
    The patch returned by ``ax.add_patch``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` have different sizes.
    """
    if np.size(x) != np.size(y):
        raise ValueError("x and y must be the same size")

    # Once the sample mean and covariance are known, this is exactly the
    # known-Gaussian case handled by plot_cov_ellipse.
    mean = (np.mean(x), np.mean(y))
    cov = np.cov(x, y)
    return plot_cov_ellipse(mean, cov, ax, n_std=n_std,
                            facecolor=facecolor, **kwargs)


def plot_cov_ellipse(m, C, ax, n_std=1.0, facecolor='none', **kwargs):
    """Draw the ``n_std``-sigma ellipse of a 2-D Gaussian on ``ax``.

    Construction follows the Matplotlib "confidence ellipse" gallery example.
    Start from an ellipse centered at the origin whose semi-axes are
    sqrt(1 + rho) and sqrt(1 - rho), where rho is the correlation coefficient
    implied by ``C``. Rotate it by 45 degrees, stretch it by ``n_std`` times
    each marginal standard deviation, and shift it to the center ``m``.

    Parameters
    ----------
    m : sequence of two floats
        Center of the ellipse, in data coordinates.
    C : 2x2 array
        Covariance matrix.
    ax : matplotlib.axes.Axes
        Axes the ellipse is added to.
    n_std : float, optional
        Number of standard deviations spanned by the ellipse.
    facecolor : color, optional
        Fill color of the ellipse.
    **kwargs
        Forwarded to ``matplotlib.patches.Ellipse``.

    Returns
    -------
    The patch returned by ``ax.add_patch``.
    """
    rho = C[0, 1] / np.sqrt(C[0, 0] * C[1, 1])

    # Semi-axes of the unit-variance ellipse before rotation/scaling
    # (eigenvalue shortcut valid for the 2-D correlation matrix).
    semi_major = np.sqrt(1 + rho)
    semi_minor = np.sqrt(1 - rho)
    patch = Ellipse((0, 0),
                    width=semi_major * 2,
                    height=semi_minor * 2,
                    facecolor=facecolor,
                    **kwargs)

    stretch_x = n_std * np.sqrt(C[0, 0])
    stretch_y = n_std * np.sqrt(C[1, 1])
    unit_to_data = (transforms.Affine2D()
                    .rotate_deg(45)
                    .scale(stretch_x, stretch_y)
                    .translate(m[0], m[1]))

    patch.set_transform(unit_to_data + ax.transData)
    return ax.add_patch(patch)

In [None]:
def make_ell_plot(tbls, min_samples=100):
    """Plot per-element-ratio error ellipses and their joint constraint.

    For each element ratio in ``tbls``, compute the sample mean and covariance
    of the rows where 'mdisk_f', 'zsun' and 'vzsun' are all finite. Three 2-D
    projections are drawn: (mdisk_f, zsun), (mdisk_f, vzsun) and
    (zsun, vzsun).

    Each panel shows 1- and 2-sigma ellipses per element ratio. Element ratios
    whose 3-D covariance has a smaller log-determinant (a tighter constraint)
    are drawn more opaque.

    The per-element 3-D Gaussians are then multiplied together by summing
    their inverse covariances. The resulting joint 1- and 2-sigma ellipses are
    drawn in black in the same three projections.

    Parameters
    ----------
    tbls : dict
        Element-ratio name -> table with 'mdisk_f', 'zsun' and 'vzsun'
        columns (presumably one row per bootstrap sample -- confirm).
    min_samples : int, optional
        An element ratio is drawn in a panel only if it has at least this
        many rows where both of that panel's columns are finite.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : array of matplotlib.axes.Axes
        One axes per panel.
    """
    params = ('mdisk_f', 'zsun', 'vzsun')
    # (x, y) index pairs into `params`, one per panel. Per-element ellipses,
    # joint ellipses, labels and fiducial lines all use these, so every panel
    # shows one consistent pair of parameters.
    panels = [(0, 1), (0, 2), (1, 2)]
    panel_xlims = [(0.4, 1.8), (0.4, 1.8), (-80, 30)]
    panel_ylims = [(-40, 40), (0, 15), (0, 15)]
    axis_labels = {
        'mdisk_f': r'${\rm M}_{\rm disk} / {\rm M}_{\rm disk}^\star$',
        'zsun': r'$z_\odot$ [pc]',
        'vzsun': r'$v_{z,\odot}$ ' + f'[{u.km/u.s:latex_inline}]',
    }

    elem_names = list(tbls.keys())
    n_par = len(params)

    # Per-element sample mean and covariance. These stay NaN when there are
    # too few finite rows (<= n_par) for a non-singular covariance.
    means = np.full((len(elem_names), n_par), np.nan)
    covs = np.full((len(elem_names), n_par, n_par), np.nan)
    for j, elem in enumerate(elem_names):
        tbl = tbls[elem]
        mask = np.logical_and.reduce([np.isfinite(tbl[p]) for p in params])
        if mask.sum() <= n_par:
            continue
        X = np.stack([tbl[p][mask] for p in params])
        covs[j] = np.cov(X)
        means[j] = np.mean(X, axis=1)

    # Only finite, positive-definite covariances can take part in the joint.
    # Degenerate ones would make np.linalg.inv raise or turn the joint
    # into NaN.
    usable = np.array([np.all(np.isfinite(cov))
                       and np.all(np.linalg.eigvalsh(cov) > 0)
                       for cov in covs], dtype=bool)
    usable_idx = np.flatnonzero(usable)

    logdets = np.full(len(elem_names), np.nan)
    for j in usable_idx:
        logdets[j] = np.linalg.slogdet(covs[j])[1]

    # Joint Gaussian (product of per-element Gaussians):
    #   C = (sum_i C_i^-1)^-1,   m = sum_i C C_i^-1 m_i
    if usable_idx.size:
        inv_covs = [np.linalg.inv(covs[j]) for j in usable_idx]
        C = np.linalg.inv(np.sum(inv_covs, axis=0))
        m = np.sum([C @ icov @ means[j]
                    for icov, j in zip(inv_covs, usable_idx)], axis=0)
    else:
        C = m = None

    finite_ld = logdets[np.isfinite(logdets)]
    lo, hi = (finite_ld.min(), finite_ld.max()) if finite_ld.size else (0., 1.)
    norm = mpl.colors.Normalize(vmin=lo, vmax=hi, clip=True)
    norm2 = mpl.colors.Normalize(vmin=-0.2, vmax=1.1)

    def get_alpha(ld):
        # Smaller log-determinant -> more opaque; the result is roughly in
        # [0.15, 0.92]. Non-finite values (no usable 3-D covariance) get the
        # faintest alpha instead of NaN, which matplotlib rejects.
        if not np.isfinite(ld):
            return float(norm2(0.))
        return float(norm2(1 - norm(ld)))

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5.5),
                             constrained_layout=True)

    for ax, (ix, iy), xlim, ylim in zip(axes, panels, panel_xlims, panel_ylims):
        col1, col2 = params[ix], params[iy]

        for elem, logdet in zip(elem_names, logdets):
            tbl = tbls[elem]
            mask = np.isfinite(tbl[col1]) & np.isfinite(tbl[col2])
            if mask.sum() < min_samples:
                continue

            color = elem_to_color[elem]
            alpha = get_alpha(logdet)
            confidence_ellipse(tbl[col1][mask], tbl[col2][mask], ax,
                               n_std=1.,
                               linewidth=0, facecolor=color,
                               alpha=alpha,
                               label=elem_to_label(elem))
            confidence_ellipse(tbl[col1][mask], tbl[col2][mask], ax,
                               n_std=2.,
                               linewidth=0, facecolor=color,
                               alpha=alpha / 2)

        if C is not None:
            # Marginalize the joint Gaussian onto this panel's two parameters.
            idx = [ix, iy]
            mm = m[idx]
            CC = C[np.ix_(idx, idx)]
            plot_cov_ellipse(mm, CC, ax=ax,
                             n_std=1.,
                             linewidth=0, facecolor='k',
                             alpha=0.5, label='joint', zorder=100)
            plot_cov_ellipse(mm, CC, ax=ax,
                             n_std=2.,
                             linewidth=0, facecolor='k',
                             alpha=0.2, zorder=100)

        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
        ax.set_xlabel(axis_labels[col1])
        ax.set_ylabel(axis_labels[col2])

        ax.axvline(fiducials[col1], zorder=-10, color='#aaaaaa', linestyle='--')
        ax.axhline(fiducials[col2], zorder=-10, color='#aaaaaa', linestyle='--')

    axes[-1].legend(ncol=2)

    fig.set_facecolor('w')

    return fig, axes

In [None]:
# Draw and save the bootstrap error-ellipse figure for every dataset.
for data_name, tbls in all_tbls.items():
    fig, axes = make_ell_plot(tbls)
    fig.suptitle(data_name, fontsize=24)

    # savefig does not create missing directories; make sure the per-dataset
    # plot directory exists first.
    out_dir = plot_path / data_name
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / 'bootstrap-error-ellipses.png', dpi=250)