<a   href="https://colab.research.google.com/github/N-Nieto/OHBM_SEA-SIG_Educational_Course/blob/master/03_pitfalls/03_06_imbalance_on_threshold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

For questions on this notebook contact: n.nieto@fz-juelich.de

### Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, roc_curve
import pandas as pd

# Seed NumPy's global random number generator so that any NumPy-based
# randomness in this notebook is repeatable from run to run.
SEED = 42
np.random.seed(SEED)

### For this experiment we will use synthetic data, which gives us more freedom to control the imbalance percentage


### Let's systematically change the imbalance and analyse the impact on the Balanced ACC and the AUC

In [None]:
# Generate datasets with different imbalance ratios.
#
# For every nominal class balance we fit a logistic regression on synthetic
# data and record: the AUC, the balanced accuracy at the default 0.5 threshold,
# and the best balanced accuracy over all ROC thresholds (with that threshold).
#
# NOTE: `weights=[ratio]` sets the share of class 0 (the negative class);
# make_classification infers the class 1 share as 1 - ratio. On top of that,
# `flip_y=0.1` assigns a random class to 10% of the samples, so the realised
# class proportions are somewhat less extreme than `ratio`.
ratios = np.arange(0.5, 0.99, 0.01) # for faster execution, use a coarser step like 0.05
results = []

for ratio in ratios:
    X, y = make_classification(n_samples=10000, n_features=20, n_informative=10,
                             n_redundant=5, weights=[ratio], flip_y=0.1,
                             random_state=42)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, y_train)

    # Predicted probability of class 1 and the ROC operating points
    probas = model.predict_proba(X_test)[:, 1]
    fpr, tpr, thresholds = roc_curve(y_test, probas)
    auc = roc_auc_score(y_test, probas)

    # Balanced accuracy is the mean of sensitivity (TPR) and specificity
    # (1 - FPR). roc_curve already reports TPR/FPR for the decision rule
    # `probas >= threshold`, so the whole curve is obtained in one vectorised
    # step instead of calling balanced_accuracy_score once per threshold.
    bal_accs = (tpr + (1.0 - fpr)) / 2.0

    # Find optimal threshold (maximizing balanced accuracy).
    # NOTE: the threshold is selected on the same test split it is scored on,
    # so 'max_bal_acc' is an optimistic upper bound used for illustration only.
    optimal_idx = np.argmax(bal_accs)
    optimal_thresh = thresholds[optimal_idx]

    results.append({
        'ratio': ratio,
        'auc': auc,
        'optimal_thresh': optimal_thresh,
        'max_bal_acc': bal_accs[optimal_idx],
        'default_thresh_bal_acc': balanced_accuracy_score(y_test, (probas >= 0.5).astype(int))
    })

# Convert results to DataFrame for easier plotting
df = pd.DataFrame(results)

### Let's now analyse the difference in Balanced Accuracy between the optimal and the default threshold

In [None]:
# Balanced accuracy at the per-ratio best threshold versus the default 0.5
# cut-off, across all simulated class balances.
# NOTE: df['ratio'] stores the value passed as `weights=[ratio]` to
# make_classification, i.e. the nominal share of class 0 (the negative class),
# so the x-axis is labelled as the negative-class proportion.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df['ratio'], df['max_bal_acc'], 'o-', label='Optimal Threshold')
ax.plot(df['ratio'], df['default_thresh_bal_acc'], 'o-', label='Default Threshold (0.5)')
ax.set_xlabel('Nominal Proportion of Negative Class (class 0)')
ax.set_ylabel('Balanced Accuracy')
ax.set_title('Impact of Threshold Selection on Balanced Accuracy')
ax.legend()
ax.grid(True)
plt.show()

### Let's see how the optimal threshold changes with respect to the class imbalance

In [None]:
# Threshold that maximises balanced accuracy for each simulated class balance,
# with the default 0.5 cut-off drawn as a dashed reference line.
# NOTE: df['ratio'] stores the value passed as `weights=[ratio]` to
# make_classification, i.e. the nominal share of class 0 (the negative class),
# so the x-axis is labelled as the negative-class proportion.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df['ratio'], df['optimal_thresh'], 'o-', label='Optimal Threshold')
ax.axhline(0.5, color='r', linestyle='--', label='Default Threshold (0.5)')
ax.set_xlabel('Nominal Proportion of Negative Class (class 0)')
ax.set_ylabel('Optimal Threshold')
ax.set_title('Optimal Threshold vs Class Imbalance')
ax.legend()
ax.grid(True)
plt.show()

## Question
What do you notice in this plot?
Is there a simple approximation for the optimal threshold?

## Question
Will the imbalance also impact a threshold-free metric like AUC?
Run the following cell and analyse the results.

In [None]:
# AUC of the fitted model for each simulated class balance.
# NOTE: df['ratio'] stores the value passed as `weights=[ratio]` to
# make_classification, i.e. the nominal share of class 0 (the negative class),
# so the x-axis is labelled as the negative-class proportion.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df['ratio'], df['auc'], 'o-')
ax.set_xlabel('Nominal Proportion of Negative Class (class 0)')
ax.set_ylabel('AUC Score')
ax.set_title('Impact of Class Imbalance on AUC')
ax.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.7, color='black')
plt.show()