In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import numpy as np

First I import the data and selectively remove all stars whose relative parallax error is 0.15 or greater (the `cutoff` threshold).  
I've included the full dataset as well as the training and test sets (which have had the binaries removed).

In [None]:
# Input catalogue and quality threshold.
filename = 'Full_DATA_r12.csv'  # file with data.
cutoff = 0.15  # maximum allowed value of 'Parallax_error'; rows at or above it are dropped

# Load the table, then keep only rows whose 'Parallax_error' is strictly below
# the cutoff (rows with a missing value also fail the comparison and are dropped).
data = pd.read_csv(filename)
data = data.loc[data['Parallax_error'] < cutoff]

Now we remove from `data` all the extra columns that the model doesn't use for predictions, keeping each one as a separate variable because we may still want it for plotting.



In [None]:
# DataFrame.pop removes a column from `data` and returns it, so every column
# popped here is kept for plotting but excluded from the model input.
Apparent = data.pop('Apparent')
Extinction = data.pop('Extinction')
parallax_error = data.pop('Parallax_error')
parallax = data.pop('parallax')

# Distance from parallax. The /1000 presumably converts milliarcseconds to
# arcseconds, so that 1/parallax is in parsecs. Computed once and reused below.
distance = (1/(parallax/1000))  # in Parsecs

g = data.pop('G')
# Absolute G magnitude from the distance modulus: M = m - 5 log10(d / 10 pc).
G_mag = g - 5 * np.log10(distance / 10)

BP = data.pop('BP')
RP = data.pop('RP')
vscatter = data.pop('vscatter')

# The target the model predicts. Whatever remains in `data` after this pop is
# what gets normalized and fed to the model.
labels = data.pop('Abs_MAG')  # This is the true values of what the model is predicting

Next I normalize the data using the statistics (mean and standard deviation) computed during training of the model.

In [None]:
# trainstats.csv is expected to hold one row per model feature. The feature
# name (matching a column of `data`) goes in the first column, and the
# training-set statistics go in 'mean' and 'std' columns. index_col=0 makes the
# feature names the index, so norm() can pair each statistic with its column.
train_stats = pd.read_csv('trainstats.csv', index_col=0)  # reads file with stats used during model training


def norm(x):
    """
    Standardize features with the training-set mean and standard deviation.

    :param x: DataFrame whose columns are the model's input features.
    :returns: DataFrame of (x - mean) / std, with the same index, columns and
        column order as ``x``.
    :raises ValueError: if ``train_stats`` has no mean/std for one of x's columns.
    """
    # Align the statistics to x's columns explicitly. Otherwise pandas aligns the
    # stats index against the column labels: any mismatch silently becomes NaN,
    # and the column order can change, which would corrupt the model input.
    stats = train_stats.reindex(x.columns)
    missing = stats.index[stats[['mean', 'std']].isna().any(axis=1)]
    if len(missing):
        raise ValueError(f'no training mean/std for feature(s): {list(missing)}')
    return (x - stats['mean']) / stats['std']


normed_data = norm(data)

Finally, I load the model and make predictions based on the data.

In [None]:
modelfile = 'singlelayerNet_15perc_cutBinaries.h5'  # trained Keras model file (.h5)
model = tf.keras.models.load_model(modelfile)

# Predicted absolute magnitude for every star, flattened to a 1-D array.
raw_predictions = model.predict(normed_data)
predictions = raw_predictions.flatten()

# Invert the extinction-corrected distance modulus, m - M - A = 5 log10(d / 10 pc):
#   d = 10 * 10 ** ((M + A - m) / -5)   [pc]
modulus_exponent = (predictions+Extinction-Apparent)/-5
distance_prediction = 10*10**modulus_exponent

# Fractional error of the predicted distance relative to the parallax distance.
distance_residual = distance - distance_prediction
distance_regression = distance_residual / distance


Note: the model accepts any N x 3 array whose columns are TEFF, Grav, and Metal (in that order, normalized as above).

Below, I've also included some of the plotting routines I use, written as functions.

In [None]:

def plot_hr(Mag):
    """
    Creates an HR diagram for data. 
    
    :param Mag: Enter either the labels array to see the true HR diagram, 
    or enter the Mag predictions that come out of the model
    """
    
    plt.figure()
    plt.scatter(data['TEFF'], Mag, alpha=0.1, s = 0.1)
    plt.xlim(7500, 3000)
    plt.xscale('log')
    plt.ylim(10, -12)
    plt.xlabel('Temp (k)')
    plt.ylabel('Absolute Magnitude (K-band)')
    plt.title('HR diagram')
    plt.show()



In [None]:
def regression_pergrav(grav_edges=range(0, 6)):
    """
    Make one regression scatter plot per log(g) bin.

    Each plot shows the notebook-level ``distance_regression``
    ((distance - distance_prediction) / distance) against ``distance`` for the
    stars whose ``data['Grav']`` falls in the bin.

    :param grav_edges: increasing Grav bin edges. Consecutive pairs form bins
        ``lo <= Grav < hi``, and the last bin also includes its upper edge.
        The default 0..5 gives the original five unit-width bins.
    """
    edges = list(grav_edges)
    grav = data['Grav']
    n_bins = len(edges) - 1
    for k in range(n_bins):
        lo, hi = edges[k], edges[k + 1]
        # Half-open bins, so a star lying exactly on an edge lands in exactly
        # one bin. The old strict lo < Grav < hi test dropped such stars entirely.
        below_hi = grav <= hi if k == n_bins - 1 else grav < hi
        in_bin = (grav >= lo) & below_hi

        fig, ax = plt.subplots()
        ax.scatter(distance[in_bin], distance_regression[in_bin], s=0.01)
        ax.set_ylim([-1, 1])
        ax.set_xlim([-100, 4000])
        ax.set_xlabel('Distance (pc)')
        ax.set_ylabel('(d_parallax - d_pred) / d_parallax')
        ax.set_title('regression plot for Grav:' + str(hi))
        plt.show()

In [None]:
def plot_regression_hist(grav_edges=range(0, 6), bins=200):
    """
    Make one regression histogram per log(g) bin.

    Each histogram shows the notebook-level ``distance_regression``
    ((distance - distance_prediction) / distance) for the stars whose
    ``data['Grav']`` falls in the bin.

    :param grav_edges: increasing Grav bin edges. Consecutive pairs form bins
        ``lo <= Grav < hi``, and the last bin also includes its upper edge.
        The default 0..5 gives the original five unit-width bins.
    :param bins: number of histogram bins in each plot (default 200).
    """
    edges = list(grav_edges)
    grav = data['Grav']
    n_bins = len(edges) - 1
    for k in range(n_bins):
        lo, hi = edges[k], edges[k + 1]
        # Half-open bins, so a star lying exactly on an edge lands in exactly
        # one bin. The old strict lo < Grav < hi test dropped such stars entirely.
        below_hi = grav <= hi if k == n_bins - 1 else grav < hi
        in_bin = (grav >= lo) & below_hi

        fig, ax = plt.subplots()
        ax.hist(distance_regression[in_bin], bins=bins)
        ax.set_xlim(-1, 1)
        ax.set_xlabel('(d_parallax - d_pred) / d_parallax')
        ax.set_ylabel('Number of stars')
        ax.set_title('regression histogram for Grav:' + str(hi) + ' -No flattening model')
        plt.show()
