In [1]:
import os, glob
import numpy as np
from tqdm import tqdm

from astropy.table import Table
from scipy.interpolate import interp1d
import pandas as pd

import tensorflow as tf
from tensorflow import keras

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns

from bokeh.plotting import figure
from bokeh.layouts import row, column
from bokeh.io import output_notebook, show
output_notebook()


# SETUP CROSSVAL

In [17]:
### DO TRAIN-VAL SPLIT
### Hold out the tail of the dataset as a test set; the head becomes the train-val
### pool that the k-fold cell below splits into train/validation folds.
### Slicing keeps the existing order, so the test set is simply the last
### (1 - TRAINVAL_FRAC) of the arrays.
### NOTE(review): assumes DATA/LABELS are not sorted by class or by target — confirm.

### FRACTION OF FLARES KEPT FOR TRAIN-VAL; THE REST IS HELD OUT FOR TESTING.
### With kfolds=5 in the k-fold cell, each fold trains on 0.8*0.9 = 72% and
### validates on 0.2*0.9 = 18% of all flares; 10% is never seen during CV.
TRAINVAL_FRAC = 0.90

### TOTAL NUMBER OF FLARES
num_flares = len(LABELS)
# The per-flare arrays are sliced with the same cutoff, so they must be aligned.
assert len(DATA) == len(PEAKS) == len(TICS) == num_flares, "DATA/LABELS/PEAKS/TICS lengths differ"

trainval_cutoff = int(TRAINVAL_FRAC * num_flares)
x_trainval = DATA[0:trainval_cutoff]
y_trainval = LABELS[0:trainval_cutoff]
p_trainval = PEAKS[0:trainval_cutoff]
t_trainval = TICS[0:trainval_cutoff]

### PRINT RESULTS TO CHECK
print("Partitioned {} out of {} flares into train-val set".format(len(y_trainval), num_flares))

Partitioned 21942 out of 24380 TCEs into train-val set


In [19]:
### SETUP K-FOLDS

### DIVIDE THE TRAIN-VAL SET INTO 5 FOLDS; EACH FOLD SERVES ONCE AS THE VALIDATION SET.
### random_state=SEED makes the shuffled fold assignment reproducible across runs,
### consistent with the seed recorded in the output-file prefix of the training cell.
kfolds = 5
kf = KFold(n_splits=kfolds, shuffle=True, random_state=SEED)


# TRAIN CROSSVAL MODEL

In [20]:
### TRAIN ONE MODEL PER CROSS-VALIDATION FOLD
### For each fold: build a fresh 1D CNN, train it, report validation metrics,
### save per-flare validation predictions to OUT_DIR, and plot the learning curves.
### Uses names from earlier cells: kf, x_trainval, y_trainval, p_trainval, t_trainval,
### SEED, FRAC_BALANCE, OUT_DIR.

NUM_EPOCHS = 350  # fixed number of training epochs per fold (no early stopping)


def _acc_key(hist):
    """Return the key Keras used to log the 'accuracy' metric in a History dict.

    TF1-era Keras logs it as 'acc' (the key this notebook was written against);
    TF2 logs it as 'accuracy'. The validation key is 'val_' + the returned key.
    """
    return 'acc' if 'acc' in hist else 'accuracy'


def _build_model(input_shape):
    """Build and compile the binary-classification 1D CNN trained on every fold.

    input_shape: per-sample shape without the batch axis, e.g. (n_timesteps, 1).
    Returns a compiled tf.keras.Sequential with a single sigmoid output, tracking
    accuracy, precision and recall.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv1D(filters=16, kernel_size=3, activation='relu', padding='same', input_shape=input_shape))
    model.add(tf.keras.layers.MaxPooling1D(pool_size=2))
    model.add(tf.keras.layers.Dropout(0.1))
    model.add(tf.keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same'))
    model.add(tf.keras.layers.MaxPooling1D(pool_size=2))
    model.add(tf.keras.layers.Dropout(0.1))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(32, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.1))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
    return model


def _plot_training_curves(hist):
    """Show per-epoch train vs. validation loss/accuracy and recall/precision (Bokeh).

    hist: the History.history dict returned by model.fit.
    """
    acc_key = _acc_key(hist)
    epochs = np.arange(len(hist['loss']))
    # NOTE: newer Bokeh deprecates the `legend=` keyword in favor of `legend_label=`.
    ### LOSS FUNCTION
    p1 = figure(width=450, height=300, title="LOSS")
    p1.line(epochs, hist['val_loss'], color='lightgray', line_width=2, legend="VAL")
    p1.line(epochs, hist['loss'], color='skyblue', line_width=2, legend="TRAIN")
    p1.legend.location = "top_right"
    ### ACCURACY METRIC
    p2 = figure(width=450, height=300, title="ACCURACY")
    p2.line(epochs, hist['val_' + acc_key], color='lightgray', line_width=2)
    p2.line(epochs, hist[acc_key], color='skyblue', line_width=2)
    show(row(p1, p2))
    ### RECALL
    p3 = figure(width=450, height=300, title="RECALL")
    p3.line(epochs, hist['recall'], color='blue', line_width=2, legend="TRAIN")
    p3.line(epochs, hist['val_recall'], color='green', line_width=2, legend="VAL")
    p3.legend.location = "bottom_right"
    ### PRECISION
    p4 = figure(width=450, height=300, title="PRECISION")
    p4.line(epochs, hist['precision'], color='blue', line_width=2)
    p4.line(epochs, hist['val_precision'], color='green', line_width=2)
    show(row(p3, p4))


def _plot_pr_curve(precision, recall):
    """Show the precision-recall curve: recall on the x axis, precision on the y axis."""
    p = figure(width=320, height=300, title='Precision vs. Recall',
               x_axis_label='Recall', y_axis_label='Precision')
    p.line(recall, precision, line_width=2)
    show(p)


### GO THROUGH FOLDS
count = 0
for train_index, val_index in kf.split(y_trainval):

    ### GRAB TRAIN AND VALIDATION SETS
    x_train = x_trainval[train_index]
    y_train = y_trainval[train_index]
    x_val = x_trainval[val_index]
    y_val = y_trainval[val_index]
    p_val = p_trainval[val_index]
    t_val = t_trainval[val_index]

    ### ADD A TRAILING CHANNEL AXIS: (n_flares, n_timesteps) -> (n_flares, n_timesteps, 1)
    x_train = x_train.reshape(x_train.shape[0], x_train.shape[1], 1)
    x_val = x_val.reshape(x_val.shape[0], x_val.shape[1], 1)
    print("\nBeginning Fold: ", count)
    print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)
    print("x_val shape:", x_val.shape, "y_val shape:", y_val.shape)

    ### SETUP MODEL (fresh weights each fold; clear_session drops the previous fold's state)
    tf.keras.backend.clear_session()
    # Input shape is taken from the data rather than hard-coded as (200, 1).
    model = _build_model(x_train.shape[1:])

    ### TRAIN MODEL
    history = model.fit(x_train, y_train, epochs=NUM_EPOCHS, batch_size=64, shuffle=True, validation_data=(x_val, y_val))

    ### CALCULATE METRICS FOR VALIDATION SET (computed once, reused by the PR plot below)
    pred_val = model.predict(x_val)
    precision, recall, _ = precision_recall_curve(y_val, pred_val)
    ap_final = average_precision_score(y_val, pred_val, average=None)
    print("Final Average Precision: ", round(ap_final, 3))
    # Last-epoch validation accuracy as logged by Keras during fit.
    print("Final Accuracy: ", round(history.history['val_' + _acc_key(history.history)][-1], 3))

    ### SAVE PREDICTIONS
    ### One row per validation flare, columns as in the header: tic, pred, gt, tpeak.
    PREFIX = 'cv' + str(count).zfill(2) + '_s'+ str(SEED).zfill(2) + '_i' + str(NUM_EPOCHS).zfill(3) + '_b' + str(FRAC_BALANCE) + '_'
    np.savetxt(os.path.join(OUT_DIR, PREFIX + 'predval.txt'), np.column_stack((t_val, pred_val, y_val, p_val)), fmt=['%.0f', '%.6f', '%.6f', '%.10f'], delimiter=',', header="tic,pred,gt,tpeak")

    ### PLOT RESULTS
    _plot_training_curves(history.history)
    _plot_pr_curve(precision, recall)

    ### CLEAN UP
    count += 1
    del model


Beginning Fold:  0
x_train shape: (17553, 200, 1) y_train shape: (17553,)
x_val shape: (4389, 200, 1) y_val shape: (4389,)
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Train on 17553 samples, validate on 4389 samples
Epoch 1/350
Epoch 2/350
Epoch 3/350
Epoch 4/350
Epoch 5/350
Epoch 6/350
Epoch 7/350
Epoch 8/350
Epoch 9/350
Epoch 10/350
Epoch 11/350
Epoch 12/350
Epoch 13/350
Epoch 14/350
Epoch 15/350
Epoch 16/350
Epoch 17/350
Epoch 18/350
Epoch 19/350
Epoch 20/350
Epoch 21/350
Epoch 22/350
Epoch 23/350
Epoch 24/350
Epoch 25/350
Epoch 26/350
Epoch 27/350
Epoch 28/350
Epoch 29/350
Epoch 30/350
Epoch 31/350
Epoch 32/350
Epoch 33/350
Epoch 34/350
Epoch 35/350
Epoch 36/350
Epoch 37/350
Epoch 38/350
Epoch 39/350
Epoch 40/350
Epoch 41/350
Epoch 42/350
Epoch 43/350
Epoch 44/350
Epoch 45/350
Epoch 46/350
Epoch 47/350
Epoch 48/350


Final Average Precision:  0.982
Final Accuracy:  0.98
