# `rail.creation` demo

_Alex Malz (GCCL@RUB)_ & _Bryce Kalmbach (UW)_

Stolen wholesale from [the XDGMM demo](https://github.com/tholoien/XDGMM/blob/master/Notebooks/Demo.ipynb).

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt

from xdgmm import XDGMM

from sklearn.model_selection import validation_curve
from sklearn.model_selection import ShuffleSplit

from astroML.plotting.tools import draw_ellipse
from astroML.plotting import setup_text_plots
from sklearn.mixture import GaussianMixture as skl_GMM

## Read in simulated data

Real simulated data are expensive to produce but include complex physical effects; the cell below uses a cheap toy placeholder instead.

In [None]:
# placeholder
# NOTE: cheap toy stand-in for the expensive simulated data described above.
# All draws come from NumPy's global RNG, seeded once with 0 below, so the
# order of the np.random calls in this cell fixes every generated value.

N = 2000
np.random.seed(0)

# would have 6 instead of just x, y
# (presumably six measured quantities per object in the real data -- TODO confirm)
# generate the true data
x_true = (1.4 + 2 * np.random.random(N)) ** 2
y_true = 0.1 * x_true ** 2

# add scatter to "true" distribution
# (per-point standard deviations that shrink as x_true grows)
dx = 0.1 + 4. / x_true ** 2
dy = 0.1 + 10. / x_true ** 2

x_true += np.random.normal(0, dx, N)
y_true += np.random.normal(0, dy, N)

# add noise to get the "observed" distribution
# dx, dy are reused here: they now hold per-point measurement errors in [0.2, 0.7)
dx = 0.2 + 0.5 * np.random.random(N)
dy = 0.2 + 0.5 * np.random.random(N)

x = x_true + np.random.normal(0, dx)
y = y_true + np.random.normal(0, dy)

# stack the results for computation
# X: (N, 2) observed points. Xerr: (N, 2, 2) per-point covariance matrices,
# diagonal, with variances dx**2 and dy**2 (off-diagonal terms stay zero).
X = np.vstack([x, y]).T
Xerr = np.zeros(X.shape + X.shape[-1:])
diag = np.arange(X.shape[-1])
Xerr[:, diag, diag] = np.vstack([dx ** 2, dy ** 2]).T

## Augment simulated data

To densely populate data space

## Model the augmented data

To get a continuous model of data and truth

### Tune the interpolation model

WARNING: SLOW!!!

In [None]:
# Score XDGMM fits with 1 through 10 components by BIC (slow: one fit each).
# bic_test returns the per-component BIC scores, the best component count
# and the lowest BIC value.
xdgmm = XDGMM()
param_range = np.arange(1, 11)
bic, optimal_n_comp, lowest_bic = xdgmm.bic_test(X, Xerr, param_range)

In [None]:
def plot_bic(param_range, bics, lowest_comp):
    """Bar-plot BIC score against the number of mixture components.

    Parameters
    ----------
    param_range : array-like of int
        Component numbers that were tested (one bar each).
    bics : array-like of float
        BIC score for each entry of ``param_range``.
    lowest_comp : int
        Component number with the lowest BIC; marked with a '*'.
    """
    # Accept lists as well as arrays: .min()/.max() and arithmetic are used below.
    param_range = np.asarray(param_range)
    bics = np.asarray(bics)

    plt.clf()
    setup_text_plots(fontsize=16, usetex=True)
    plt.figure(figsize=(12, 6))

    # Centre each bar on its tick. Matplotlib >= 2.0 defaults to
    # align='center', so the old ``param_range - 0.25`` shift (written for
    # left-edge-aligned bars) pushed every bar off its tick and off the '*'.
    plt.bar(param_range, bics, color='blue', width=0.5, align='center')
    # Star sits 3% of the BIC span above the lowest score.
    plt.text(lowest_comp, bics.min() * 0.97 + .03 * bics.max(), '*',
             fontsize=14, ha='center')

    # Zoom the y-axis to the range of scores (with 1% padding) so small BIC
    # differences are visible.
    span = bics.max() - bics.min()
    plt.ylim(bics.min() - 0.01 * span, bics.max() + 0.01 * span)
    plt.xlim(param_range.min() - 1, param_range.max() + 1)

    plt.xticks(param_range, fontsize=14)
    plt.yticks(fontsize=14)

    plt.xlabel('Number of components', fontsize=18)
    plt.ylabel('BIC score', fontsize=18)

    plt.show()

In [None]:
plot_bic(param_range=param_range, bics=bic, lowest_comp=optimal_n_comp)

In [None]:
# Cross-validate the number of components over 3 random 70/30 train/test
# splits. XDGMM's fit takes (X, Xerr), so the uncertainties are routed
# through scikit-learn's ``y`` argument.
param_range = np.arange(1, 11)
shuffle_split = ShuffleSplit(n_splits=3, test_size=0.3, random_state=0)

train_scores, test_scores = validation_curve(
    xdgmm,
    X=X,
    y=Xerr,
    param_name="n_components",
    param_range=param_range,
    n_jobs=3,
    cv=shuffle_split,
    verbose=1,
)

# One row per component number; summarise across the CV splits (axis 1).
train_scores_mean, train_scores_std = train_scores.mean(axis=1), train_scores.std(axis=1)
test_scores_mean, test_scores_std = test_scores.mean(axis=1), test_scores.std(axis=1)

In [None]:
def plot_val_curve(param_range, train_mean, train_std, test_mean,
                   test_std):
    """Plot training and cross-validation scores versus component number.

    Each score curve is drawn with a shaded band spanning mean +/- std.
    """
    plt.clf()
    setup_text_plots(fontsize=16, usetex=True)
    plt.figure(figsize=(12, 8))

    curves = (
        (train_mean, train_std, "Training score", "red"),
        (test_mean, test_std, "Cross-validation score", "green"),
    )
    for mean, spread, label, colour in curves:
        plt.plot(param_range, mean, label=label, color=colour)
        plt.fill_between(param_range, mean - spread, mean + spread,
                         alpha=0.2, color=colour)

    plt.legend(loc="best")
    plt.xlabel("Number of Components", fontsize=18)
    plt.ylabel("Score", fontsize=18)
    plt.xlim(param_range.min(), param_range.max())
    plt.show()

In [None]:
plot_val_curve(param_range,
               train_mean=train_scores_mean, train_std=train_scores_std,
               test_mean=test_scores_mean, test_std=test_scores_std)

### Interpolate the augmented data

In [None]:
# Refit on the full data set using the BIC-preferred component count
# (optimal_n_comp from the bic_test cell). fit's return value is rebound to
# xdgmm -- presumably the fitted estimator itself; confirm against XDGMM docs.
xdgmm.n_components = optimal_n_comp
xdgmm = xdgmm.fit(X, Xerr)

In [None]:
# Write the fitted model to 'demo_model.fit'; the next cell reloads it from there.
xdgmm.save_model('demo_model.fit')

## Introduce systematics

I.e., optionally degrade the model to add realistic complexity.

In [None]:
# Reload the saved model two ways: into the existing object, and as a
# freshly constructed XDGMM built from the same file.
xdgmm.read_model('demo_model.fit')
xdgmm2 = XDGMM(filename='demo_model.fit')

# Sanity check: both objects should print identical weight arrays.
for reloaded in (xdgmm, xdgmm2):
    print(reloaded.weights)

## Emulate data and truth

Draw data and truth from continuous model

In [None]:
# Draw N synthetic points from the fitted continuous model. plot_sample reads
# columns 0 and 1 of the result, so it is presumably an (N, 2) array.
sample = xdgmm.sample(N)

In [None]:
def plot_sample(x_true, y_true, x, y, sample, xdgmm):
    """Compare true, noisy and XD-resampled data in a 2x2 grid of panels.

    Top left: true points; top right: noisy observations; bottom left:
    points sampled from the fitted model; bottom right: the model's
    component ellipses (drawn with ``draw_ellipse(..., scales=[2])``).
    """
    setup_text_plots(fontsize=16, usetex=True)
    plt.clf()
    fig = plt.figure(figsize=(12, 9))
    fig.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.95,
                        wspace=0.02, hspace=0.02)

    point_style = dict(s=4, lw=0, c='k')
    scatter_panels = [
        (221, x_true, y_true),
        (222, x, y),
        (223, sample[:, 0], sample[:, 1]),
    ]
    axes = []
    for position, xs, ys in scatter_panels:
        panel = fig.add_subplot(position)
        panel.scatter(xs, ys, **point_style)
        axes.append(panel)

    ellipse_panel = fig.add_subplot(224)
    for k in range(xdgmm.n_components):
        draw_ellipse(xdgmm.mu[k], xdgmm.V[k], scales=[2], ax=ellipse_panel,
                     ec='k', fc='gray', alpha=0.2)
    axes.append(ellipse_panel)

    titles = ["True Distribution", "Noisy Distribution",
              "Extreme Deconvolution\n  resampling",
              "Extreme Deconvolution\n  cluster locations"]

    for idx, (panel, title) in enumerate(zip(axes, titles)):
        panel.set_xlim(-1, 13)
        panel.set_ylim(-6, 16)
        panel.xaxis.set_major_locator(plt.MultipleLocator(4))
        panel.yaxis.set_major_locator(plt.MultipleLocator(5))
        panel.text(0.05, 0.95, title,
                   ha='left', va='top', transform=panel.transAxes)

        # Only the bottom row gets x labels, only the left column y labels.
        top_row = idx < 2
        right_column = idx % 2 == 1
        if top_row:
            panel.xaxis.set_major_formatter(plt.NullFormatter())
        else:
            panel.set_xlabel('$x$', fontsize=18)
        if right_column:
            panel.yaxis.set_major_formatter(plt.NullFormatter())
        else:
            panel.set_ylabel('$y$', fontsize=18)

    plt.show()

In [None]:
plot_sample(x_true, y_true, x, y, sample=sample, xdgmm=xdgmm)

## Write out photometry

Eventually in the `ceci` HDF5 format used by `rail.estimation`; CSV for now.

## Evaluate true redshift posterior of sampled data

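Evaluate the continuous model at the observed values of the sampled data, i.e. condition it on those values to obtain the posterior of the remaining quantity (below, $x$ given $y$).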

In [None]:
# Condition the joint model on y = 1.5 (error 0.05 on y, 0.0 on x); the NaN
# entry leaves x unconditioned.
cond_X = np.array([np.nan, 1.5])
cond_Xerr = np.array([0.0, 0.05])
cond_xdgmm = xdgmm.condition(X_input=cond_X, Xerr_input=cond_Xerr)

# Compare the conditioned model with the original: weights first, then means.
for model in (xdgmm, cond_xdgmm):
    print(model.weights)
print("\n")
for model in (xdgmm, cond_xdgmm):
    print(model.mu)

In [None]:
def plot_cond_model(xdgmm, cond_xdgmm, y):
    plt.clf()
    setup_text_plots(fontsize=16, usetex=True)
    fig = plt.figure(figsize=(12, 9))

    ax1 = fig.add_subplot(111)
    for i in range(xdgmm.n_components):
        draw_ellipse(xdgmm.mu[i], xdgmm.V[i], scales=[2], ax=ax1,
                     ec='k', fc='gray', alpha=0.2)

    ax1.plot([-2,15],[y,y],color='blue',linewidth=2)
    ax1.set_xlim(-1, 13)
    ax1.set_ylim(-6, 16)
    ax1.set_xlabel('$x$', fontsize = 18)
    ax1.set_ylabel('$y$', fontsize = 18)

    ax2 = ax1.twinx()
    x = np.array([np.linspace(-2,14,1000)]).T

    gmm=skl_GMM(n_components = cond_xdgmm.n_components,
                covariance_type = 'full')
    gmm.means_ = cond_xdgmm.mu
    gmm.weights_ = cond_xdgmm.weights
    gmm.covars_ = cond_xdgmm.V

    logprob, responsibilities = gmm.score_samples(x)

    pdf = np.exp(logprob)
    ax2.plot(x, pdf, color='red', linewidth = 2,
             label='Cond. dist. of $x$ given $y='+str(y)+'\pm 0.05$')
    ax2.legend()
    ax2.set_ylabel('Probability', fontsize= 18 )
    ax2.set_ylim(0, 0.52)
    ax1.set_xlim(-1, 13)
    plt.show()

def plot_cond_sample(x, y):
    """Histogram (50 bins, step outline) of samples of x.

    ``y`` is accepted but not used by this function.
    """
    plt.clf()
    setup_text_plots(fontsize=16, usetex=True)
    plt.figure(figsize=(12, 9))

    plt.hist(x, bins=50, histtype='step', color='red', lw=2)

    plt.xlim(-1, 13)
    plt.ylim(0, 70)

    plt.xlabel('$x$', fontsize=18)
    plt.ylabel('Number of Points', fontsize=18)
    plt.show()

def plot_conditional_predictions(y, true_x, predicted_x):
    """Scatter true and predicted x values against the same y values."""
    plt.clf()
    setup_text_plots(fontsize=16, usetex=True)
    plt.figure(figsize=(12, 9))

    series = (
        (true_x, 'red', "True Distribution"),
        (predicted_x, 'blue', "Predicted Distribution"),
    )
    for xs, colour, label in series:
        plt.scatter(xs, y, color=colour, s=4, marker='o', label=label)

    plt.xlim(-1, 13)
    plt.ylim(-6, 16)
    plt.legend(loc=2, scatterpoints=1)

    plt.xlabel('$x$', fontsize=18)
    plt.ylabel('$y$', fontsize=18)
    plt.show()


In [None]:
plot_cond_model(xdgmm=xdgmm, cond_xdgmm=cond_xdgmm, y=1.5)

In [None]:
# Draw 1000 values from the model conditioned on y = 1.5 and histogram them.
n_cond = 1000
cond_sample = cond_xdgmm.sample(n_cond)
y = np.full(n_cond, 1.5)
plot_cond_sample(cond_sample, y)

In [None]:
# Simulate a dataset from the joint model, then predict each x from its y by
# drawing one sample from the conditional model p(x | y).
n_test = 1000
true_sample = xdgmm.sample(n_test)
true_x = true_sample[:, 0]
y = true_sample[:, 1]

# Collect one draw per y and concatenate once at the end. Calling np.append
# inside the loop copied the whole growing array every iteration (O(n^2)).
# The draw order is unchanged, so results match for the same RNG state.
draws = []
for this_y in y:
    # Specify y-conditioning to apply to P(x,y); NaN leaves x free.
    on_this = np.array([np.nan, this_y])
    # Compute conditional PDF P(x|y):
    cond_gmm = xdgmm.condition(on_this)
    # Flatten each draw, as np.append did, before the single concatenation.
    draws.append(np.ravel(cond_gmm.sample()))
predicted_x = np.concatenate(draws) if draws else np.array([])

# Plot the two datasets, to compare the true x and the predicted x:
plot_conditional_predictions(y, true_x, predicted_x)

## Write out true redshifts and posteriors

Write these separately from the photometry so that estimation can be done blind; the format must be compatible with `rail.evaluation`.