In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# --- Run selection -----------------------------------------------------------
# These three identifiers pick the trained model and the test set to evaluate.
project = 'newfault'
traindate = '2024-10-01'
testdate = '2024-10-02'

# --- Directory layout --------------------------------------------------------
# Every directory string keeps its trailing slash because the cells below build
# file names by plain string concatenation.
traindate_path = f'/home/sdybing/gnss-picker/cnn_models_outputs/{project}_fq_train/models/traindate_{traindate}/'
test_outputs_path = f'{traindate_path}data/fakequakes_testing/classification_stats/'
figure_save_dir = f'{traindate_path}figures/fakequakes_testing/classification_stats/'

# --- Test-set arrays ---------------------------------------------------------
# All four .npy files live in <traindate_path>data/ and are prefixed with the test date.
fqtest_data = np.load(f'{traindate_path}data/{testdate}_fqtest_norm_data.npy')
fqtest_metadata = np.load(f'{traindate_path}data/{testdate}_fqtest_metadata.npy')
fqtest_target = np.load(f'{traindate_path}data/{testdate}_fqtest_target.npy')
fqtest_predictions = np.load(f'{traindate_path}data/{testdate}_fqtest_predictions.npy')

# --- Sweep setup -------------------------------------------------------------
num_fqtest = len(fqtest_predictions) # One prediction per test sample
thresholds = np.arange(0, 1.005, 0.005) # Candidate decision thresholds, nominally 0 to 1 in steps of 0.005
test_thresholds = [0, 0.005] # Short list for quick trials; not referenced by the cells visible here

In [3]:
# Threshold sweep: for every candidate threshold, turn each prediction into a binary
# "earthquake" (1) / "noise" (0) call, compare against the binarized targets, and
# record accuracy, precision, recall and F1.

accuracies = []      # Fraction of correct calls (0-1), one entry per threshold
accuracies_per = []  # Same quantity expressed as a percentage
precisions = []
recalls = []
F1s = []

num_preds = num_fqtest # Total number of predictions

# Fail early with a readable message instead of a ZeroDivisionError / IndexError
# (or silently ignoring extra targets) further down.
if num_preds == 0:
    raise ValueError('fqtest_predictions is empty; there is nothing to score.')
if len(fqtest_target) != num_preds:
    raise ValueError('fqtest_target has ' + str(len(fqtest_target)) + ' samples but fqtest_predictions has ' + str(num_preds) + '.')

# One row per test sample, whatever the trailing shape of the arrays is.
preds_by_sample = np.asarray(fqtest_predictions).reshape(num_preds, -1)
targs_by_sample = np.asarray(fqtest_target).reshape(num_preds, -1)

# Convert the target arrays to 1s and 0s (signal or noise): a target counts as signal if
# any of its values is above zero. This does not depend on the threshold, so it is
# computed once here instead of once per threshold.
targ_is_quake = np.any(targs_by_sample > 0, axis = 1)
targ_binary = targ_is_quake.astype(float)

for ind in range(len(thresholds)):

    # Round away floating-point noise from np.arange so the printed value is clean
    threshold = np.round(thresholds[ind],3)

    print('-------------------------------------------------------------')
    print('Threshold: ' + str(threshold))

    # Convert the prediction arrays to 1s and 0s: if anywhere in the prediction the output
    # reaches or exceeds the threshold, the sample is called an earthquake
    pred_is_quake = np.any(preds_by_sample >= threshold, axis = 1)
    pred_binary = pred_is_quake.astype(float)

    # Confusion-matrix counts
    num_true_pos = int(np.count_nonzero(pred_is_quake & targ_is_quake))    # There is an earthquake, and the model found it
    num_true_neg = int(np.count_nonzero(~pred_is_quake & ~targ_is_quake))  # There isn't an earthquake, and the model found just noise
    num_false_pos = int(np.count_nonzero(pred_is_quake & ~targ_is_quake))  # There isn't an earthquake, and the model thought it found one
    num_false_neg = int(np.count_nonzero(~pred_is_quake & targ_is_quake))  # There is an earthquake, and the model missed it
    num_correct_preds = num_true_pos + num_true_neg
    num_wrong_preds = num_false_pos + num_false_neg

    # Calculating the accuracy, precision, recall, and F1

    accuracy = num_correct_preds / num_preds
    accuracy_per = (num_correct_preds / num_preds) * 100
    print('Accuracy: ' + str(accuracy_per) + '%')

    # Precision is reported as 0 when the model made no positive calls at all
    if num_true_pos == 0 and num_false_pos == 0:
        precision = 0
    else:
        precision = num_true_pos / (num_true_pos + num_false_pos)

    # Recall is reported as 0 when the test set contains no positive targets
    if num_true_pos == 0 and num_false_neg == 0:
        recall = 0
    else:
        recall = num_true_pos / (num_true_pos + num_false_neg)

    # F1 is the harmonic mean of precision and recall; reported as 0 when both are 0
    if precision + recall == 0:
        F1 = 0
    else:
        F1 = 2 * ((precision * recall) / (precision + recall))

    accuracies.append(accuracy)
    accuracies_per.append(accuracy_per)
    precisions.append(precision)
    recalls.append(recall)
    F1s.append(F1)

# print('Accuracies')
# print(accuracies)
# print('Precisions')
# print(precisions)
# print('Recalls')
# print(recalls)
# print('F1s')
# print(F1s)


-------------------------------------------------------------
Threshold: 0.0
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.005
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.01
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.015
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.02
Accuracy: 50.0%
-------------------------------------------------------------
Threshold: 0.025
Accuracy: 50.00218007412251%
-------------------------------------------------------------
Threshold: 0.03
Accuracy: 50.33682145192937%
-------------------------------------------------------------
Threshold: 0.035
Accuracy: 52.19533464137781%
-------------------------------------------------------------
Threshold: 0.04
Accuracy: 54.43645083932853%
-------------------------------------------------------------
Threshold: 0.045
Accuracy: 56.0922171

In [4]:
# Write the threshold grid and each metric list to <test_outputs_path><name>.npy.
# Note that 'accuracies' holds fractions (0-1); the percentage list is not saved.
for stat_name, stat_values in (
    ('thresholds', thresholds),
    ('accuracies', accuracies),
    ('precisions', precisions),
    ('recalls', recalls),
    ('F1s', F1s),
):
    np.save(test_outputs_path + stat_name + '.npy', stat_values)


In [5]:
# Find threshold with highest accuracy
#
# np.argmax returns the index of the FIRST maximum, so ties resolve to the lowest
# threshold. Unlike a "update only when strictly greater than 0" scan, it always
# assigns best_thresh, so an all-zero accuracy curve can neither raise a NameError
# nor silently reuse a stale best_thresh left in the kernel by an earlier run.

if len(accuracies_per) == 0:
    raise ValueError('accuracies_per is empty; run the threshold sweep cell first.')

best_idx = int(np.argmax(accuracies_per))
acc0 = accuracies_per[best_idx] # Highest accuracy (%) found in the sweep
best_thresh = thresholds[best_idx] # Threshold at which that accuracy occurs

print(best_thresh)

0.135


In [7]:
def _save_metric_curve(x_values, y_values, title, filename, save_dir,
                       ylabel = None, ymax = 1, marker_x = None, marker_label = None):
    """Plot one metric against threshold and save it as a PNG.

    Parameters
    ----------
    x_values : sequence of float
        Thresholds for the x axis (axis limits are fixed to 0-1).
    y_values : sequence of float
        Metric value at each threshold.
    title : str
        Figure title; also used as the y-axis label unless ``ylabel`` is given.
    filename : str
        File name appended directly to ``save_dir``.
    save_dir : str
        Output directory string, expected to end with a path separator.
    ylabel : str, optional
        Y-axis label when it should differ from the title.
    ymax : float, optional
        Upper y-axis limit (the lower limit is always 0).
    marker_x : float, optional
        If given, draw a dashed red vertical line at this x and show a legend.
    marker_label : str, optional
        Legend text for the vertical line.
    """
    fig = plt.figure(figsize = (8,5), dpi = 300)
    plt.plot(x_values, y_values, linewidth = 2)
    plt.xlabel('Threshold', fontsize = 18)
    plt.ylabel(title if ylabel is None else ylabel, fontsize = 18)
    plt.xlim(0,1)
    plt.ylim(0,ymax)
    if marker_x is not None:
        plt.axvline(marker_x, color = 'red', linestyle = '--', alpha = 0.6, label = marker_label)
    plt.xticks(fontsize = 15)
    plt.yticks(fontsize = 15)
    plt.title(title, fontsize = 18)
    if marker_x is not None:
        plt.legend()
    # plt.show();
    plt.savefig(save_dir + filename, format = 'PNG', facecolor = 'white')
    plt.close(fig) # Figures are only written to disk, so release them immediately


# Accuracy is plotted in percent and marked at the best threshold. The label is rounded
# to 3 decimals so floating-point noise from np.arange never shows up in the legend.
_save_metric_curve(thresholds, accuracies_per, 'Accuracy', 'accuracy_by_threshold.png', figure_save_dir,
                   ylabel = 'Accuracy (%)', ymax = 100, marker_x = best_thresh,
                   marker_label = 'Max accuracy at\nthreshold of ' + str(np.round(best_thresh, 3)))

# The remaining metrics are plotted on a 0-1 y axis with no marker line.
_save_metric_curve(thresholds, precisions, 'Precision', 'precision_by_threshold.png', figure_save_dir)
_save_metric_curve(thresholds, recalls, 'Recall', 'recall_by_threshold.png', figure_save_dir)
_save_metric_curve(thresholds, F1s, 'F1', 'F1_by_threshold.png', figure_save_dir)