# **Class Imbalance Handling for CTZ & GEN**

**`Hypothesis:`** GEN/CTZ underperformance due to severe class imbalance (23%/34% resistant)
Test SMOTE, class weights, and threshold optimization

In [10]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
from imblearn.over_sampling import SMOTE
from imblearn.combine import SMOTETomek
from collections import Counter

## **LOAD DATA**

In [None]:
data = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_multi_data.csv")
pheno = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_pheno.csv", index_col=0)
X = data.drop('prename', axis=1).values

In [11]:
#EXPERIMENT CONFIG
ANTIBIOTICS = ['CTZ', 'GEN']  # Focus on problematic ones
STRATEGIES = {
    'baseline': {'use_smote': False, 'class_weight': None},
    'class_weight': {'use_smote': False, 'class_weight': 'balanced'},
    'smote': {'use_smote': True, 'class_weight': None},
    'smote_balanced': {'use_smote': True, 'class_weight': 'balanced'}
}

In [12]:
results = []

for ab in ANTIBIOTICS:
    print(f"\n{'='*60}")
    print(f"TESTING: {ab}")
    print(f"{'='*60}")

    y = pheno[ab].values
    print(f"Original distribution: {Counter(y)}")
    print(f"Imbalance ratio: {y.sum()}/{len(y)} = {y.sum()/len(y):.2%} resistant")

    for strategy_name, config in STRATEGIES.items():
        print(f"\n--- Strategy: {strategy_name} ---")

        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        fold_aucs = []
        fold_f1s = []

        for fold, (train_idx, val_idx) in enumerate(cv.split(X, y), 1):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            #apply SMOTE if configured
            if config['use_smote']:
                smote = SMOTE(random_state=42, k_neighbors=3)
                X_train, y_train = smote.fit_resample(X_train, y_train)
                if fold == 1:  #print once
                    print(f"  After SMOTE: {Counter(y_train)}")

            #train RF
            rf = RandomForestClassifier(
                n_estimators=200,
                class_weight=config['class_weight'],
                random_state=42,
                n_jobs=-1
            )
            rf.fit(X_train, y_train)

            #evaluate
            y_pred_proba = rf.predict_proba(X_val)[:, 1]
            y_pred_class = rf.predict(X_val)

            auc = roc_auc_score(y_val, y_pred_proba)
            f1 = f1_score(y_val, y_pred_class)

            fold_aucs.append(auc)
            fold_f1s.append(f1)

        mean_auc = np.mean(fold_aucs)
        std_auc = np.std(fold_aucs)
        mean_f1 = np.mean(fold_f1s)

        #compare to baseline
        if strategy_name == 'baseline':
            baseline_auc = mean_auc
            improvement = 0
        else:
            improvement = mean_auc - baseline_auc

        result = {
            'Antibiotic': ab,
            'Strategy': strategy_name,
            'AUC_Mean': round(mean_auc, 4),
            'AUC_Std': round(std_auc, 4),
            'F1_Mean': round(mean_f1, 4),
            'Improvement_vs_Baseline': round(improvement, 4),
            'Status': 'Succeeded' if improvement > 0.01 else ('Acceptable' if improvement > 0 else 'Failed')
        }
        results.append(result)

        print(f"  AUC: {mean_auc:.4f}±{std_auc:.4f} | F1: {mean_f1:.4f} | Δ: {improvement:+.4f} {result['Status']}")


TESTING: CTZ
Original distribution: Counter({np.int64(0): 533, np.int64(1): 276})
Imbalance ratio: 276/809 = 34.12% resistant

--- Strategy: baseline ---
  AUC: 0.8453±0.0410 | F1: 0.7123 | Δ: +0.0000 Failed

--- Strategy: class_weight ---
  AUC: 0.8441±0.0432 | F1: 0.7156 | Δ: -0.0012 Failed

--- Strategy: smote ---
  After SMOTE: Counter({np.int64(0): 426, np.int64(1): 426})
  AUC: 0.8432±0.0404 | F1: 0.7195 | Δ: -0.0021 Failed

--- Strategy: smote_balanced ---
  After SMOTE: Counter({np.int64(0): 426, np.int64(1): 426})
  AUC: 0.8432±0.0404 | F1: 0.7195 | Δ: -0.0021 Failed

TESTING: GEN
Original distribution: Counter({np.int64(0): 621, np.int64(1): 188})
Imbalance ratio: 188/809 = 23.24% resistant

--- Strategy: baseline ---
  AUC: 0.7656±0.0302 | F1: 0.4702 | Δ: +0.0000 Failed

--- Strategy: class_weight ---
  AUC: 0.7771±0.0269 | F1: 0.4533 | Δ: +0.0115 Succeeded

--- Strategy: smote ---
  After SMOTE: Counter({np.int64(0): 496, np.int64(1): 496})
  AUC: 0.7638±0.0284 | F1: 0.496

## **ANALYSIS**

In [13]:
results_df = pd.DataFrame(results)

print("\n" + "="*80)
print("CLASS IMBALANCE HANDLING RESULTS")
print("="*80)
print(results_df.to_string(index=False))

#find best strategy per antibiotic
print("\n" + "="*80)
print("BEST STRATEGIES")
print("="*80)
for ab in ANTIBIOTICS:
    ab_results = results_df[results_df['Antibiotic'] == ab]
    best = ab_results.loc[ab_results['AUC_Mean'].idxmax()]
    print(f"{ab}: {best['Strategy']} → AUC={best['AUC_Mean']:.4f} (Δ={best['Improvement_vs_Baseline']:+.4f})")

# Save results
results_df.to_csv("/content/drive/MyDrive/ML-iAMR_Recreation/05_evaluation/results/EXP-005_imbalance_results.csv", index=False)
print("\nResults saved to results/EXP-005_imbalance_results.csv")


CLASS IMBALANCE HANDLING RESULTS
Antibiotic       Strategy  AUC_Mean  AUC_Std  F1_Mean  Improvement_vs_Baseline    Status
       CTZ       baseline    0.8453   0.0410   0.7123                   0.0000    Failed
       CTZ   class_weight    0.8441   0.0432   0.7156                  -0.0012    Failed
       CTZ          smote    0.8432   0.0404   0.7195                  -0.0021    Failed
       CTZ smote_balanced    0.8432   0.0404   0.7195                  -0.0021    Failed
       GEN       baseline    0.7656   0.0302   0.4702                   0.0000    Failed
       GEN   class_weight    0.7771   0.0269   0.4533                   0.0115 Succeeded
       GEN          smote    0.7638   0.0284   0.4963                  -0.0018    Failed
       GEN smote_balanced    0.7638   0.0284   0.4963                  -0.0018    Failed

BEST STRATEGIES
CTZ: baseline → AUC=0.8453 (Δ=+0.0000)
GEN: class_weight → AUC=0.7771 (Δ=+0.0115)

Results saved to results/EXP-005_imbalance_results.csv


In [14]:
#RECOMMENDATIONS
print("\n" + "="*80)
print("RECOMMENDATIONS")
print("="*80)

for ab in ANTIBIOTICS:
    ab_results = results_df[results_df['Antibiotic'] == ab]
    best_strategy = ab_results.loc[ab_results['AUC_Mean'].idxmax(), 'Strategy']
    improvement = ab_results.loc[ab_results['AUC_Mean'].idxmax(), 'Improvement_vs_Baseline']

    if improvement > 0.02:
        print(f" {ab}: Use '{best_strategy}' strategy (significant improvement: +{improvement:.4f})")
    elif improvement > 0.01:
        print(f" {ab}: Consider '{best_strategy}' (marginal improvement: +{improvement:.4f})")
    else:
        print(f" {ab}: Class imbalance NOT the primary issue. Investigate feature quality/hyperparameters.")


RECOMMENDATIONS
 CTZ: Class imbalance NOT the primary issue. Investigate feature quality/hyperparameters.
 GEN: Consider 'class_weight' (marginal improvement: +0.0115)
