In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Configuration: which trained model (project + training date) and which test run to evaluate,
# and where the classification statistics and figures for that run are written.
project = 'newfault'
traindate = '2024-10-01'
testdate = '2024-10-02'
traindate_path = '/home/sdybing/gnss-picker/cnn_models_outputs/' + project + '_fq_train/models/traindate_' + traindate + '/'
test_outputs_path = traindate_path + 'data/fakequakes_testing/classification_stats/'
figure_save_dir = traindate_path + 'figures/fakequakes_testing/classification_stats/'

# FakeQuakes test set saved by the testing run: input data, metadata, targets, and model predictions.
fqtest_data = np.load(traindate_path + 'data/' + testdate + '_fqtest_norm_data.npy')
fqtest_metadata = np.load(traindate_path + 'data/' + testdate + '_fqtest_metadata.npy')
fqtest_target = np.load(traindate_path + 'data/' + testdate + '_fqtest_target.npy')
fqtest_predictions = np.load(traindate_path + 'data/' + testdate + '_fqtest_predictions.npy')

num_fqtest = len(fqtest_predictions)
# The threshold sweep pairs prediction k with target k, so the two files must have the same length.
assert len(fqtest_target) == num_fqtest, 'fqtest_target and fqtest_predictions have different lengths'

# Decision thresholds to sweep: 0 to 1 in steps of 0.005. The stop is 1.005 so 1.0 falls inside
# np.arange's half-open range.
thresholds = np.arange(0, 1.005, 0.005)
# NOTE(review): not referenced elsewhere in this notebook; looks like a short list for quick debug runs.
test_thresholds = [0, 0.005]

In [3]:
# Threshold sweep: score the binary "earthquake present?" classification at every threshold.
# Each test sample is collapsed to a single label:
#   prediction -> 1 if any value of the predicted curve reaches the threshold, else 0
#   target     -> 1 if any value of the target curve is > 0 (signal), else 0 (noise)
# This is vectorized with NumPy. The previous version rebuilt both label arrays and re-counted
# TP/TN/FP/FN with per-sample Python loops for every threshold, which was slow enough that the
# sweep had to be interrupted.

# Target labels do not depend on the threshold, so compute them once, outside the sweep.
# reshape(n, -1) flattens everything after the first axis, so "any value" works for any per-sample shape.
targ_binary = (np.asarray(fqtest_target).reshape(len(fqtest_target), -1) > 0).any(axis=1)
pred_rows = np.asarray(fqtest_predictions).reshape(num_fqtest, -1)

accuracies = []      # fraction of correct predictions, per threshold
accuracies_per = []  # same, as a percentage
precisions = []
recalls = []
F1s = []

for threshold in np.round(thresholds, 3):  # round away floating-point noise from np.arange

    print('-------------------------------------------------------------')
    print('Threshold: ' + str(threshold))

    pred_binary = (pred_rows >= threshold).any(axis=1)

    num_true_pos = int(np.count_nonzero(pred_binary & targ_binary))    # earthquake present, model found it
    num_true_neg = int(np.count_nonzero(~pred_binary & ~targ_binary))  # noise only, model found just noise
    num_false_pos = int(np.count_nonzero(pred_binary & ~targ_binary))  # noise only, model thought it found one
    num_false_neg = int(np.count_nonzero(~pred_binary & targ_binary))  # earthquake present, model missed it
    num_correct_preds = num_true_pos + num_true_neg

    accuracy = num_correct_preds / num_fqtest
    accuracy_per = accuracy * 100
    print('Accuracy: ' + str(accuracy_per) + '%')

    # Precision/recall are defined as 0 when their denominators are empty (no positive predictions /
    # no positive targets); F1 is 0 when both are 0.
    if num_true_pos == 0 and num_false_pos == 0:
        precision = 0
    else:
        precision = num_true_pos / (num_true_pos + num_false_pos)

    if num_true_pos == 0 and num_false_neg == 0:
        recall = 0
    else:
        recall = num_true_pos / (num_true_pos + num_false_neg)

    if precision + recall == 0:
        F1 = 0
    else:
        F1 = 2 * ((precision * recall) / (precision + recall))

    accuracies.append(accuracy)
    accuracies_per.append(accuracy_per)
    precisions.append(precision)
    recalls.append(recall)
    F1s.append(F1)


-------------------------------------------------------------
Threshold: 0.0
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.005
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.01



KeyboardInterrupt



In [None]:
# Persist the sweep so later cells (and later sessions) can reload it without re-running it.
# np.save does not create missing directories, so make sure the output directory exists first.
os.makedirs(test_outputs_path, exist_ok = True)
for stat_name, stat_values in [('thresholds', thresholds),
                               ('accuracies', accuracies),
                               ('precisions', precisions),
                               ('recalls', recalls),
                               ('F1s', F1s)]:
    np.save(test_outputs_path + stat_name + '.npy', stat_values)


In [None]:
# Find threshold with highest accuracy.
# np.argmax returns the FIRST index of the maximum, the same result as scanning and updating only on a
# strictly higher accuracy. Unlike that scan, it always defines best_thresh: the scan left it undefined
# if every accuracy was 0. The value is rounded to 3 decimals, like the thresholds used in the sweep.
best_idx = int(np.argmax(accuracies_per))
best_thresh = np.round(thresholds[best_idx], 3)

print(best_thresh)

In [4]:
# Threshold with the highest accuracy, recovered from the saved sweep results instead of being
# hard-coded (previously 0.135), so it stays correct when the model is retrained or retested.
# np.argmax takes the first maximum, matching the best-threshold scan; rounded like the sweep thresholds.
_saved_thresholds = np.load(test_outputs_path + 'thresholds.npy')
_saved_accuracies = np.load(test_outputs_path + 'accuracies.npy')
best_thresh = float(np.round(_saved_thresholds[np.argmax(_saved_accuracies)], 3))

In [5]:
# Reload the saved sweep: the threshold grid plus one value per threshold for each metric.
thresholds, accuracies, precisions, recalls, F1s = (
    np.load(test_outputs_path + stem + '.npy')
    for stem in ('thresholds', 'accuracies', 'precisions', 'recalls', 'F1s')
)

In [6]:
# Accuracy at the chosen best threshold. np.isclose replaces exact ==: the thresholds come from
# np.arange with a float step, so equality against a decimal value can silently match nothing.
r = np.where(np.isclose(thresholds, best_thresh))[0]
print(accuracies[r])

[0.59660998]


In [7]:
# Peak precision, every threshold that attains it, and the accuracy at each of those thresholds.
best_precision = np.max(precisions)
e = np.flatnonzero(precisions == best_precision)
print(best_precision)
print(thresholds[e])
print(accuracies[e])

1.0
[0.835 0.84  0.845 0.85  0.855 0.86  0.865 0.87  0.875 0.88  0.885 0.89
 0.895 0.9   0.905 0.91  0.915 0.92  0.925 0.93  0.935 0.94  0.945 0.95
 0.955 0.96  0.965 0.97  0.975 0.98  0.985 0.99 ]
[0.54983649 0.54870286 0.54744931 0.54620667 0.54452801 0.54308917
 0.54159581 0.53997166 0.5381077  0.53600392 0.53417266 0.53207979
 0.52956181 0.52682581 0.52429693 0.52162634 0.51914105 0.51635056
 0.51372357 0.51123828 0.5087639  0.50650752 0.50442555 0.502954
 0.50190756 0.50116634 0.50063222 0.50029431 0.5001417  0.5000654
 0.5000327  0.5000109 ]


In [8]:
# Peak recall, every threshold that attains it, and the accuracy at each of those thresholds.
best_recall = np.max(recalls)
e = np.flatnonzero(recalls == best_recall)
print(best_recall)
print(thresholds[e])
print(accuracies[e])

1.0
[0.    0.005 0.01  0.015 0.02  0.025]
[0.5       0.5       0.5       0.5       0.5       0.5000218]


In [9]:
# Peak F1, every threshold that attains it, and the accuracy at each of those thresholds.
best_f1 = np.max(F1s)
e = np.flatnonzero(F1s == best_f1)
print(best_f1)
print(thresholds[e])
print(accuracies[e])

0.6666763560258125
[0.025]
[0.5000218]


In [10]:
# One stand-alone figure per metric (accuracy, precision, recall, F1) versus threshold.
# All four metrics are fractions in [0, 1] (`accuracies` holds fractions; `accuracies_per` holds the
# percentages), so every plot uses ylim (0, 1). The accuracy plot previously used ylim (0, 100),
# which squashed its curve against the bottom axis.

def _save_metric_by_threshold(values, metric_name, filename, mark_best = False):
    """Plot one metric against `thresholds` and save it as a PNG in `figure_save_dir`.

    values: one metric value per entry of `thresholds`, expected to lie in [0, 1].
    metric_name: used for both the y-axis label and the title.
    filename: output file name, appended to `figure_save_dir`.
    mark_best: if True, draw a dashed red line at `best_thresh` and a legend entry for it.
    """
    fig, ax = plt.subplots(figsize = (8,5), dpi = 300)
    ax.plot(thresholds, values, linewidth = 2)
    ax.set_xlabel('Threshold', fontsize = 18)
    ax.set_ylabel(metric_name, fontsize = 18)
    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    if mark_best:
        ax.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6, label = 'Max accuracy at\nthreshold of ' + str(best_thresh))
        ax.legend()
    ax.tick_params(axis = 'both', labelsize = 15)
    ax.set_title(metric_name, fontsize = 18)
    fig.savefig(figure_save_dir + filename, format = 'PNG', facecolor = 'white')
    plt.close(fig)  # figures are only written to disk, not shown inline

os.makedirs(figure_save_dir, exist_ok = True)  # savefig does not create missing directories

_save_metric_by_threshold(accuracies, 'Accuracy', 'accuracy_by_threshold.png', mark_best = True)
_save_metric_by_threshold(precisions, 'Precision', 'precision_by_threshold.png')
_save_metric_by_threshold(recalls, 'Recall', 'recall_by_threshold.png')
_save_metric_by_threshold(F1s, 'F1', 'F1_by_threshold.png')

In [11]:
# Subplot version

In [13]:
# Figure 4 of the manuscript: accuracy, precision, recall, and F1 versus threshold on a 2x2 grid,
# each with a dashed red line at best_thresh. The rows touch (hspace = 0), so the top row hides
# its x tick labels and only the bottom row carries the 'Threshold' label.
plt.figure(figsize = (8,5), dpi = 400, facecolor = 'white')

fig4_panels = (
    (221, '(a)', accuracies, 'Accuracy'),
    (222, '(b)', precisions, 'Precision'),
    (223, '(c)', recalls, 'Recall'),
    (224, '(d)', F1s, 'F1'),
)

for grid_pos, panel_letter, metric, metric_label in fig4_panels:
    in_top_row = grid_pos in (221, 222)

    plt.subplot(grid_pos)
    plt.text(x = -0.3, y = 0.9, s = panel_letter, fontsize = 22)
    plt.grid(lw = 0.5, zorder = 0)
    plt.plot(thresholds, metric, linewidth = 2)
    if not in_top_row:
        plt.xlabel('Threshold', fontsize = 12)
    plt.ylabel(metric_label, fontsize = 12)
    plt.xlim(0,1)
    plt.ylim(0,1)

    if in_top_row:
        plt.tick_params(axis = 'both', bottom = False, labelbottom = False)
        plt.yticks(fontsize = 11)
    else:
        plt.xticks(fontsize = 11)
        # 1.0 is left off the bottom-row y axis (presumably so it does not overlap the panel above)
        plt.yticks([0, 0.2, 0.4, 0.6, 0.8], fontsize = 11)

    if grid_pos == 221:
        # Only panel (a) labels the line; its legend is anchored far below, under the bottom row.
        plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6, label = 'Max accuracy at threshold of ' + str(best_thresh))
        plt.legend(loc = (0.55,-1.45))
    else:
        plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6)

plt.subplots_adjust(hspace = 0, wspace = 0.32, bottom = 0.2)

plt.savefig('/home/sdybing/gnss-picker/manuscript_figures/Figure_4.png', format = 'PNG')
plt.close()

In [5]:
# Subplot version with confusion matrix, no F1

In [34]:
# Variant figure: accuracy, precision, and recall versus threshold in panels (a)-(c). The (d) slot is
# left empty (output file name ends in "_blankd"); presumably a confusion matrix is added there later,
# per this section's heading. Uses the pyplot state machine: every plt.* call applies to the most
# recently created subplot.
plt.figure(figsize = (8,8), dpi = 400)

# (a) Accuracy. Carries the only legend entry (the best-threshold line).
plt.subplot(221)
plt.text(x = -0.3, y = 0.9, s = '(a)', fontsize = 22)
plt.grid(lw = 0.5, zorder = 0)
plt.plot(thresholds, accuracies, linewidth = 2)
plt.xlabel('Threshold', fontsize = 12)
plt.ylabel('Accuracy', fontsize = 12)
plt.xlim(0,1)
plt.ylim(0,1)
plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6, label = 'Max accuracy at\nthreshold of ' + str(best_thresh))
# plt.tick_params(axis = 'both', bottom = False, labelbottom = False)
plt.yticks(fontsize = 11)
# Tuple loc = lower-left corner of the legend in panel (a)'s axes coordinates, i.e. far below the panel.
plt.legend(loc = (0.13,-1.67))

# plt.legend()

# (b) Precision.
plt.subplot(222)
plt.text(x = -0.3, y = 0.9, s = '(b)', fontsize = 22)
plt.grid(lw = 0.5, zorder = 0)
plt.plot(thresholds, precisions, linewidth = 2)
plt.xlabel('Threshold', fontsize = 12)
plt.ylabel('Precision', fontsize = 12)
plt.xlim(0,1)
plt.ylim(0,1)
# plt.tick_params(axis = 'both', bottom = False, labelbottom = False)
plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6)
plt.yticks(fontsize = 11)

# (c) Recall. The y axis omits the 1.0 tick.
plt.subplot(223)
plt.text(x = -0.3, y = 0.9, s = '(c)', fontsize = 22)
plt.grid(lw = 0.5, zorder = 0)
plt.plot(thresholds, recalls, linewidth = 2)
plt.xlabel('Threshold', fontsize = 12)
plt.ylabel('Recall', fontsize = 12)
plt.xlim(0,1)
plt.ylim(0,1)
plt.xticks(fontsize = 11)
plt.yticks([0, 0.2, 0.4, 0.6, 0.8], fontsize = 11)
plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6)

# (d) Intentionally blank. No subplot(224) is created, so the '(d)' label below is drawn on panel (c)'s
# axes at data x = 1.018, just right of panel (c), marking where the empty (d) slot sits.
# Keep this call after the panel (c) block: it depends on (c) being the current axes.
# The commented-out lines are the former F1 panel, kept for reference.
# plt.subplot(224)
plt.text(x = 1.018, y = 0.9, s = '(d)', fontsize = 22)
# plt.grid(lw = 0.5, zorder = 0)
# plt.plot(thresholds, F1s, linewidth = 2)
# plt.xlabel('Threshold', fontsize = 12)
# plt.ylabel('F1', fontsize = 12)
# plt.xlim(0,1)
# plt.ylim(0,1)
# plt.xticks(fontsize = 11)
# plt.yticks([0, 0.2, 0.4, 0.6, 0.8], fontsize = 11)
# plt.axvline(best_thresh, color = 'red', linestyle = '--', alpha = 0.6)

# bottom = 0.25 leaves room under the grid for the legend anchored below panel (a).
plt.subplots_adjust(hspace = 0.25, wspace = 0.32, bottom = 0.25)

# plt.show()

plt.savefig('/home/sdybing/gnss-picker/manuscript_figures/fq_testdata_classification_blankd.png', format = 'PNG')
plt.close();