In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Experiment identifiers: which trained model (project + training date) and which test run to read.
project = 'newfault'
traindate = '2024-10-01'
testdate = '2024-10-02'

# Output directory layout for that trained model; every path ends with a trailing '/'
# so later cells can append file names with plain string concatenation.
traindate_path = f'/home/sdybing/gnss-picker/cnn_models_outputs/{project}_fq_train/models/traindate_{traindate}/'
test_outputs_path = f'{traindate_path}data/'
figure_save_dir = f'{traindate_path}figures/'






In [2]:
# All FakeQuakes test-set arrays for this test date share one file-name prefix.
fqtest_prefix = f'{test_outputs_path}{testdate}_fqtest_'

fqtest_data = np.load(fqtest_prefix + 'norm_data.npy')
fqtest_metadata = np.load(fqtest_prefix + 'metadata.npy')
fqtest_target = np.load(fqtest_prefix + 'target.npy')
fqtest_predictions = np.load(fqtest_prefix + 'predictions.npy')

In [3]:
# Number of test samples, and the decision threshold noted as coming from "code 2".
num_fqtest = len(fqtest_predictions)
best_thresh = 0.025  # From code 2

# Append one zero-filled column to the metadata as an initial placeholder
# (NOTE(review): presumably filled with per-sample results later — not visible here).
zeros = np.zeros((num_fqtest, 1))
analysis_array = np.c_[fqtest_metadata, zeros]

In [4]:
results = np.load(f'{test_outputs_path}fakequakes_testing/fqtest_metadata_withresults_0025.npy')  # '_0025' file; presumably the 0.025-threshold run

In [5]:
results[..., 3]  # column 3 holds each sample's outcome label ('true pos', 'false pos', ...)

array(['true pos', 'true pos', 'false pos', ..., 'true pos', 'true pos',
       'true pos'], dtype='<U32')

In [7]:
threshs = ['0025', '0135', '0835']

# Fixed colour-scale ceiling shared by every threshold's figure so the panels are
# directly comparable; it sits just above the total number of test samples.
CMAP_VMAX = 92000
CLASSES = ['Noise', 'Earthquake']
# Outcome labels stored in column 3 of the results array, in the order count_outcomes returns them.
OUTCOME_LABELS = ('true pos', 'true neg', 'false pos', 'false neg')
MANUSCRIPT_FIG_DIR = '/home/sdybing/gnss-picker/manuscript_figures/'


def load_results(thresh):
    """Load the metadata-with-results array for one threshold tag (e.g. '0025').

    The '0135' run was saved without a threshold suffix, so it is special-cased.
    """
    suffix = '' if thresh == '0135' else '_' + thresh
    return np.load(test_outputs_path + 'fakequakes_testing/fqtest_metadata_withresults' + suffix + '.npy')


def count_outcomes(labels):
    """Count outcome labels and return them as a tuple (tp, tn, fp, fn).

    Raises ValueError if any label is not in OUTCOME_LABELS. An unrecognised
    label would otherwise be dropped from the counts, and the confusion matrix
    and its percentages would be silently wrong.
    """
    values, counts = np.unique(labels, return_counts=True)
    tally = dict(zip(values.tolist(), counts.tolist()))
    unexpected = set(tally) - set(OUTCOME_LABELS)
    if unexpected:
        raise ValueError(f'Unexpected outcome label(s) in results: {sorted(unexpected)}')
    return tuple(tally.get(label, 0) for label in OUTCOME_LABELS)


def plot_confusion_matrix(conf_matrix, total, save_path, vmax=CMAP_VMAX):
    """Draw a 2x2 confusion matrix and save it as a JPEG.

    The actual class runs along x and the predicted class along y.

    conf_matrix : 2x2 array laid out as [[TN, FP], [FN, TP]]
                  (rows = actual class, columns = predicted class).
    total       : denominator for the percentage printed in each cell.
    save_path   : output image path.
    vmax        : upper limit of the colour scale.
    """
    names = np.array([['TN', 'FP'],
                      ['FN', 'TP']])
    fig, ax = plt.subplots(dpi=400)

    # Transpose so cell (x=i, y=j) shows conf_matrix[i, j], i.e. actual on x, predicted on y.
    cax = ax.imshow(conf_matrix.T, cmap='Blues', vmin=0, vmax=vmax)

    # Minor ticks on the cell edges draw a grid between cells; hide the tick marks themselves.
    ax.set_xticks(np.arange(-0.5, conf_matrix.shape[0], 1), minor=True)
    ax.set_yticks(np.arange(-0.5, conf_matrix.shape[1], 1), minor=True)
    ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)
    ax.tick_params(which='minor', bottom=False, left=False)

    # Colour bar in scientific notation; powerlimits (0, 0) forces a x10^n offset.
    cbar = fig.colorbar(cax, ax=ax, label='Count')
    formatter = ticker.ScalarFormatter(useMathText=True)
    formatter.set_scientific(True)
    formatter.set_powerlimits((0, 0))
    cbar.ax.yaxis.set_major_formatter(formatter)

    class_ticks = np.arange(len(CLASSES))
    ax.set_xticks(class_ticks)
    ax.set_yticks(class_ticks)
    ax.set_xticklabels(CLASSES, fontsize=12)
    ax.set_yticklabels(CLASSES, fontsize=12, rotation=90, va='center')
    ax.set_xlabel('Actual label', fontsize=14)
    ax.set_ylabel('Predicted label', fontsize=14)

    # Each cell shows its name, its share of all samples, and its raw count.
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            color = 'tomato' if i != j else 'black'  # off-diagonal cells (FP, FN) are the errors
            count = conf_matrix[i, j]
            pct = count / total * 100 if total else 0.0  # guard against an empty results file
            ax.text(i, j - 0.18, f'{names[i, j]}', ha='center', va='center', color=color, fontsize=24)
            ax.text(i, j, f'{pct:.1f}%', ha='center', va='center', color=color, fontsize=20)
            ax.text(i, j + 0.15, f'{count}', ha='center', va='center', color=color, fontsize=14)

    fig.savefig(save_path, format='JPG')
    plt.close(fig)  # many figures are produced in a loop; free each one explicitly


# Per-threshold counts, so they do not have to be recovered from leaked loop variables.
# tp/tn/fp/fn, results and conf_matrix are still left at module level (last threshold)
# for any later cell that reads them.
counts_by_thresh = {}
for thresh in threshs:
    results = load_results(thresh)
    tp, tn, fp, fn = count_outcomes(results[:, 3])
    counts_by_thresh[thresh] = {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn}

    # Format: [[TN, FP], [FN, TP]]
    conf_matrix = np.array([[tn, fp],
                            [fn, tp]])
    plot_confusion_matrix(conf_matrix, len(results),
                          MANUSCRIPT_FIG_DIR + 'fq_testdata_confusion_matrix_' + thresh + '.jpg')

In [12]:
# NOTE: tp/tn/fp/fn are whatever the confusion-matrix loop left behind, i.e. the counts
# for the LAST threshold it processed ('0835'), not necessarily best_thresh.
print(*(tp, tn, fp, fn), sep='\n')

4572
45870
0
41298
