## PR Curves to find the optimal threshold

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_curve, f1_score, accuracy_score
import matplotlib.pyplot as plt

# Load the per-tile predictions
df = pd.read_csv('tile_df.csv')

# Parse the 'Probability' column, which stores both class probabilities as one
# string per row (e.g. "[0.12 0.88]").
# The brackets are stripped BEFORE splitting so that padding after '[' (as in
# "[ 0.12  0.88]") or comma separators do not produce an empty token that
# float() would reject. Each string is split only once.
prob_columns = (
    df['Probability']
    .str.strip()
    .str.strip('[]')
    .str.replace(',', ' ', regex=False)
    .str.split(expand=True)
)
df['prob_class0'] = prob_columns[0].astype(float)
df['prob_class1'] = prob_columns[1].astype(float)

# True labels
y_true = df['Label']

# Probabilities of class 0, which is treated as the positive class below
prob_class0 = df['prob_class0']

# Precision-recall curve with class 0 as the positive class.
# precision and recall have one more entry than thresholds: the final
# (precision=1, recall=0) point has no associated threshold.
precision, recall, thresholds = precision_recall_curve(y_true, prob_class0, pos_label=0)

# Pad thresholds with a trailing 1 so it lines up element-wise with
# precision and recall; the padded entry pairs with that final point.
thresholds = np.append(thresholds, 1)

# F1 score per threshold. np.divide with `where=` skips the entries whose
# denominator is zero (they keep the 0 pre-filled in `out`), so no 0/0 is ever
# evaluated and no RuntimeWarning is raised.
pr_sum = precision + recall
f1_scores = np.divide(
    2 * (precision * recall),
    pr_sum,
    out=np.zeros_like(pr_sum),
    where=pr_sum != 0,
)

# Index of the maximum F1 score (first one on ties)
optimal_idx = np.argmax(f1_scores)

# Metrics at the F1-optimal threshold
optimal_threshold = thresholds[optimal_idx]
optimal_precision = precision[optimal_idx]
optimal_recall = recall[optimal_idx]
optimal_f1 = f1_scores[optimal_idx]

# Calculate accuracy for each threshold.
# A sample is predicted as class 0 (the positive class of the PR curve) when
# prob_class0 >= threshold, so the boolean prediction must be compared with
# "the true label is class 0". Comparing it with the raw labels instead would
# equate True with label 1 and yield the error rate rather than the accuracy.
is_class0 = np.asarray(y_true == 0)
class0_scores = np.asarray(prob_class0)
accuracies = np.array(
    [((class0_scores >= threshold) == is_class0).mean() for threshold in thresholds]
)

# Accuracy at the F1-optimal threshold
optimal_accuracy = accuracies[optimal_idx]

# Restrict every curve to the threshold window that will be plotted
min_threshold = 0.2
max_threshold = 0.8
filtered_indices = np.logical_and(thresholds >= min_threshold, thresholds <= max_threshold)

# Apply the same boolean mask to each metric so the arrays stay aligned
(
    filtered_thresholds,
    filtered_precision,
    filtered_recall,
    filtered_accuracies,
    filtered_f1_scores,
) = (
    curve[filtered_indices]
    for curve in (thresholds, precision, recall, accuracies, f1_scores)
)

# Report the metrics at the F1-optimal threshold, one per line
print(
    f'Optimal Threshold: {optimal_threshold}',
    f'Optimal Precision: {optimal_precision}',
    f'Optimal Recall: {optimal_recall}',
    f'Optimal F1 Score: {optimal_f1}',
    f'Optimal Accuracy: {optimal_accuracy}',
    sep='\n',
)

# Draw each metric against the threshold, limited to the filtered window
plt.figure(figsize=(10, 8))
curves = (
    (filtered_precision, 'Precision'),
    (filtered_recall, 'Recall'),
    (filtered_accuracies, 'Accuracy'),
    (filtered_f1_scores, 'F1 Score'),
)
for series, series_label in curves:
    plt.plot(filtered_thresholds, series, marker='.', markersize=4, label=series_label)

# Highlight the threshold that maximises F1
plt.scatter([optimal_threshold], [optimal_f1], marker='o', color='red', label='Optimal Threshold (F1)')

# Axis labels, title, legend and grid
plt.xlabel('Threshold')
plt.ylabel('Score')
plt.title('Precision, Recall, Accuracy, and F1 Score vs. Threshold')
plt.legend()
plt.grid(True)
plt.show()