# Imbalance Handling Techniques Analysis for Readmission Prediction

This notebook analyzes and compares different techniques for handling class imbalance in the readmission prediction model:

1. Baseline (no imbalance handling)
2. Class weights (class_weight='balanced')
3. Random oversampling
4. SMOTE (Synthetic Minority Over-sampling Technique)
5. Random undersampling

For each technique, we generate:
- PR curves
- F1 scores
- Precision
- Recall
- PR AUC

All metrics are computed with 5-fold cross-validation (`cv_folds=5`), so that no single train/test split determines the result.
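
One detail worth making explicit: when resampling is combined with cross-validation, the resampling step should touch only the training portion of each fold, so that validation folds keep the original class distribution. The sketch below illustrates this with NumPy alone. The helpers `stratified_folds` and `oversample_minority` and the toy labels are assumptions made for this illustration; they are not part of `src.models.imbalance_analysis`, whose internals are not shown here.

```python
import numpy as np

def stratified_folds(y, n_folds, rng):
    """Assign each sample a fold id so every fold keeps roughly the class ratio."""
    fold_ids = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold_ids[idx] = np.arange(len(idx)) % n_folds
    return fold_ids

def oversample_minority(idx, y, rng):
    """Return training indices with minority rows duplicated until classes match."""
    classes, counts = np.unique(y[idx], return_counts=True)
    minority = classes[np.argmin(counts)]
    min_idx = idx[y[idx] == minority]
    extra = rng.choice(min_idx, size=counts.max() - counts.min(), replace=True)
    return np.concatenate([idx, extra])

rng = np.random.default_rng(42)
y_toy = np.array([0] * 90 + [1] * 10)   # hypothetical labels: 10% positives
fold_ids = stratified_folds(y_toy, n_folds=5, rng=rng)

fold_report = []
for k in range(5):
    train_idx = np.flatnonzero(fold_ids != k)
    val_idx = np.flatnonzero(fold_ids == k)
    # Resample the training fold only; the validation fold is left untouched
    train_idx = oversample_minority(train_idx, y_toy, rng)
    fold_report.append((y_toy[train_idx].mean(), y_toy[val_idx].mean()))
    print(f"fold {k}: train positive rate {fold_report[-1][0]:.2f}, "
          f"validation positive rate {fold_report[-1][1]:.2f}")
```

In this toy setup every training fold ends up balanced while every validation fold keeps the original 10% positive rate, which is the property to check for in any resampling pipeline.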

In [None]:
# Import necessary libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Tuple, Any

# Add project root to path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import project modules
from src.utils import get_logger, load_config, get_data_path, get_project_root
from src.models.model import ReadmissionModel
from src.models.imbalance_analysis import (
    load_data, preprocess_data, create_imbalance_pipelines,
    evaluate_pipelines, plot_pr_curves, plot_metrics_comparison,
    save_results_to_csv, analyze_imbalance_techniques
)

# Set up logging
logger = get_logger(__name__)

# Set plot style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

## 1. Load and Explore Data

First, let's load the data and explore the class distribution to understand the imbalance.

In [None]:
# Load data
data = load_data()
print(f"Loaded data with {data.shape[0]} rows and {data.shape[1]} columns")

# Display first few rows
data.head()

In [None]:
# Preprocess data
X, y = preprocess_data(data)
print(f"Preprocessed features shape: {X.shape}")

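# Visualize class distribution
# Sort by class label so the counts line up with the bar positions,
# which seaborn orders by label (for numeric labels) rather than by frequency
class_counts = y.value_counts().sort_index()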
plt.figure(figsize=(10, 6))
ax = sns.barplot(x=class_counts.index, y=class_counts.values)
plt.title('Class Distribution for 30-day Readmission')
plt.xlabel('Readmission')
plt.ylabel('Count')

# Add percentage labels
total = len(y)
for i, count in enumerate(class_counts.values):
    percentage = count / total * 100
    ax.text(i, count + 5, f"{count} ({percentage:.1f}%)", ha='center')

plt.show()

# Print imbalance ratio (majority count : minority count)
imbalance_ratio = class_counts.max() / class_counts.min()
print(f"Imbalance ratio: {imbalance_ratio:.2f}:1")

## 2. Implement and Evaluate Imbalance Handling Techniques

Now, let's implement and evaluate different techniques for handling the class imbalance.

In [None]:
# Create pipelines for different techniques
pipelines = create_imbalance_pipelines(random_state=42)
print(f"Created {len(pipelines)} pipelines for evaluation:")
for name in pipelines.keys():
    print(f"- {name}")

In [None]:
# Evaluate pipelines using cross-validation
results = evaluate_pipelines(X, y, pipelines, cv_folds=5, random_state=42)

# Create a summary dataframe
summary = []
for name, result in results.items():
    summary.append({
        "Technique": name,
        "Precision": result["precision"],
        "Recall": result["recall"],
        "F1 Score": result["f1"],
        "PR AUC": result["pr_auc"]
    })

summary_df = pd.DataFrame(summary)
summary_df.set_index("Technique", inplace=True)
summary_df.style.highlight_max(axis=0, color='lightgreen').format("{:.4f}")

## 3. Visualize Results

### 3.1 Precision-Recall Curves

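PR curves are particularly useful for imbalanced classification because both precision and recall are computed with respect to the positive class (here, readmission) and neither counts true negatives, so a large majority class cannot inflate the curve.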
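
To make the quantities on the two axes concrete, the following sketch computes precision and recall for every "predict the top-k scores as positive" cut-off using NumPy. The helper `pr_points` and the toy labels and scores are assumptions made for this illustration, and the helper assumes there are no tied scores; it is not the implementation behind `evaluate_pipelines`.

```python
import numpy as np

def pr_points(y_true, scores):
    """Precision and recall when the k highest scores are predicted positive, k = 1..n."""
    order = np.argsort(-scores, kind="stable")   # highest score first
    y_sorted = y_true[order]
    tp = np.cumsum(y_sorted)                     # true positives among the top k
    k = np.arange(1, len(y_true) + 1)            # number of predicted positives
    precision = tp / k
    recall = tp / y_true.sum()
    return precision, recall

# Hypothetical labels (30% positives) and model scores
y_toy = np.array([0, 0, 1, 0, 1, 0, 0, 1, 0, 0])
scores = np.array([0.1, 0.3, 0.8, 0.2, 0.6, 0.7, 0.05, 0.4, 0.15, 0.25])

precision, recall = pr_points(y_toy, scores)
prevalence = y_toy.mean()

print(f"precision when everything is predicted positive: {precision[-1]:.2f}")
print(f"positive-class prevalence:                       {prevalence:.2f}")
```

When every sample is predicted positive, precision equals the positive-class prevalence. A PR AUC is therefore best read relative to that prevalence rather than relative to 0.5.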

In [None]:
# Plot PR curves
plt.figure(figsize=(12, 8))

for name, result in results.items():
    if result["precision_curve"] is not None and result["recall_curve"] is not None:
        plt.plot(
            result["recall_curve"], 
            result["precision_curve"],
            label=f"{name} (PR AUC = {result['pr_auc']:.3f})"
        )

plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curves for Different Imbalance Handling Techniques")
plt.legend(loc="best")
plt.grid(True)
plt.show()

### 3.2 Metrics Comparison

Let's compare the key metrics (Precision, Recall, F1, PR AUC) across all techniques.

In [None]:
# Plot metrics comparison
metrics = ["precision", "recall", "f1", "pr_auc"]
pipeline_names = list(results.keys())

# Extract metrics for each pipeline
metric_values = {metric: [results[name][metric] for name in pipeline_names] for metric in metrics}

# Create the plot
fig, ax = plt.subplots(figsize=(14, 8))

x = np.arange(len(pipeline_names))
width = 0.2
multiplier = 0

for metric, values in metric_values.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, values, width, label=metric.upper())
    # %-style format strings work on every Matplotlib release that provides bar_label
    ax.bar_label(rects, fmt="%.2f", padding=3)
    multiplier += 1

ax.set_ylabel("Score")
ax.set_title("Comparison of Metrics Across Imbalance Handling Techniques")
# Center each tick under its group of bars
ax.set_xticks(x + width * (len(metrics) - 1) / 2, pipeline_names)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()

### 3.3 Heatmap of Metrics

A heatmap can provide a clear visual comparison of all metrics across techniques.

In [None]:
# Create heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(summary_df, annot=True, cmap="YlGnBu", fmt=".3f", linewidths=.5)
plt.title("Heatmap of Evaluation Metrics Across Imbalance Handling Techniques")
plt.tight_layout()
plt.show()

## 4. Discussion of Trade-offs

Let's analyze the trade-offs between different imbalance handling techniques.

### 4.1 Baseline vs. Class Weights

Class weights scale each sample's contribution to the training loss by a per-class factor, so errors on the minority class cost more while the data itself stays unchanged. This typically improves recall at the expense of precision.
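
As a rough illustration of what `class_weight='balanced'` does, scikit-learn documents the balanced weight of a class as `n_samples / (n_classes * count_of_that_class)`. The sketch below reproduces that formula with NumPy; the helper name `balanced_class_weights` and the toy labels are assumptions made for this example.

```python
import numpy as np

def balanced_class_weights(y):
    """Per-class weights following n_samples / (n_classes * class_count)."""
    classes, counts = np.unique(y, return_counts=True)
    weights = len(y) / (len(classes) * counts)
    return dict(zip(classes.tolist(), weights.tolist()))

y_toy = np.array([0] * 90 + [1] * 10)   # hypothetical labels: 10% positives
weights = balanced_class_weights(y_toy)
print(weights)

# With these weights, both classes carry the same total weight in the loss
total_weight_neg = weights[0] * np.sum(y_toy == 0)
total_weight_pos = weights[1] * np.sum(y_toy == 1)
print(total_weight_neg, total_weight_pos)
```

The rarer class receives the larger weight, and the summed weight per class comes out equal, which is the sense in which the classes are "balanced".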

Let's compare the metrics:

In [None]:
# Compare Baseline vs. Class Weights
comparison_df = summary_df.loc[["Baseline", "Class Weights"]]
comparison_df.style.highlight_max(axis=0, color='lightgreen').format("{:.4f}")

### 4.2 Random Oversampling vs. SMOTE

Random oversampling duplicates existing minority samples. Because the model sees the exact same minority samples several times, it can overfit to them.

SMOTE instead creates synthetic examples by interpolating between a minority sample and one of its nearest minority-class neighbors. This exposes the model to a more varied set of minority examples and may improve generalization, but it can also introduce noise when the synthetic samples are not representative of the true data distribution, for example when they fall in regions where the classes overlap.
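
The interpolation step can be sketched in a few lines of NumPy. This is a simplified illustration of the idea, not the imbalanced-learn implementation; the function `smote_like_samples`, its parameters, and the random toy data are all assumptions made for this example.

```python
import numpy as np

def smote_like_samples(X_min, n_new, k, rng):
    """Create n_new synthetic rows, each on the segment between a minority row
    and one of its k nearest minority neighbours (Euclidean distance)."""
    assert k < len(X_min), "need more minority rows than neighbours"
    diff = X_min[:, None, :] - X_min[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=2))
    np.fill_diagonal(dist, np.inf)                 # a row is not its own neighbour
    neighbours = np.argsort(dist, axis=1)[:, :k]   # k nearest rows per minority row

    base = rng.integers(0, len(X_min), size=n_new)              # starting rows
    chosen = neighbours[base, rng.integers(0, k, size=n_new)]   # one neighbour each
    gap = rng.random((n_new, 1))                                # position on the segment
    return X_min[base] + gap * (X_min[chosen] - X_min[base])

rng = np.random.default_rng(42)
X_min = rng.normal(size=(10, 3))   # hypothetical minority-class features
X_new = smote_like_samples(X_min, n_new=20, k=3, rng=rng)
print(X_new.shape)
```

Every synthetic row lies between two real minority rows, so it never leaves the range already covered by the minority class. This is also why interpolation only makes sense for continuous features.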

In [None]:
# Compare Random Oversampling vs. SMOTE
comparison_df = summary_df.loc[["Random Oversampling", "SMOTE"]]
comparison_df.style.highlight_max(axis=0, color='lightgreen').format("{:.4f}")

### 4.3 Oversampling vs. Undersampling

Oversampling techniques (Random Oversampling, SMOTE) increase the number of minority class samples to balance the classes, preserving all available information but potentially leading to longer training times and overfitting.

Undersampling reduces the number of majority class samples, which can lead to information loss but may help prevent the model from being biased towards the majority class and can reduce training time.
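
The difference in how much data each strategy keeps is easy to see on toy labels. The sketch below uses a hypothetical helper, `random_resample`, written for this illustration with NumPy only; it is not the code used by `create_imbalance_pipelines`.

```python
import numpy as np

def random_resample(y, strategy, rng):
    """Return row indices after random over- or undersampling to equal class counts."""
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max() if strategy == "over" else counts.min()
    parts = []
    for cls, count in zip(classes, counts):
        idx = np.flatnonzero(y == cls)
        if count < target:
            # Oversampling: add randomly chosen duplicates (with replacement)
            extra = rng.choice(idx, size=target - count, replace=True)
            idx = np.concatenate([idx, extra])
        elif count > target:
            # Undersampling: keep a random subset (without replacement)
            idx = rng.choice(idx, size=target, replace=False)
        parts.append(idx)
    return np.concatenate(parts)

rng = np.random.default_rng(42)
y_toy = np.array([0] * 90 + [1] * 10)   # hypothetical labels: 10% positives

over_idx = random_resample(y_toy, "over", rng)
under_idx = random_resample(y_toy, "under", rng)

print("oversampled class counts: ", np.bincount(y_toy[over_idx]))
print("undersampled class counts:", np.bincount(y_toy[under_idx]))
print("distinct original rows used:",
      len(np.unique(over_idx)), "vs", len(np.unique(under_idx)))
```

Oversampling trains on more rows but still uses every original row, whereas undersampling discards most of the majority class, which matters a great deal on a dataset this small.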

In [None]:
# Compare Oversampling vs. Undersampling
comparison_df = summary_df.loc[["Random Oversampling", "SMOTE", "Random Undersampling"]]
comparison_df.style.highlight_max(axis=0, color='lightgreen').format("{:.4f}")

## 5. Conclusion

### Key Findings

Based on the analysis above, we can draw the following conclusions:

1. **Best Overall Technique**: [To be filled based on actual results]
2. **Precision-Recall Trade-off**: [To be filled based on actual results]
3. **SMOTE vs. Random Oversampling**: [To be filled based on actual results]

### Limitations

It's important to note the limitations of this analysis:

1. **Small Dataset Size**: With only ~200 demo patients, the absolute performance metrics may be unstable and not generalizable. The relative differences between techniques are more informative than the absolute values.

2. **Cross-validation Stability**: Even with cross-validation, the small dataset size means that the results may vary significantly depending on the random splits.

3. **Model Simplicity**: We used logistic regression for all techniques to focus on the imbalance handling methods, but more complex models might interact differently with these techniques.

### Next Steps

For a more comprehensive analysis, consider:

1. Testing these techniques on the full MIMIC dataset when available
2. Exploring combinations of techniques (e.g., SMOTE + class weights)
3. Trying different base classifiers (e.g., random forest, XGBoost)
4. Implementing more advanced techniques like ADASYN, SMOTETomek, or SMOTEENN
5. Exploring threshold optimization for each technique
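
As a starting point for the threshold optimization item, one possible approach is to scan the observed scores as candidate thresholds and keep the one with the highest F1. The sketch below is a minimal version of that idea; `best_f1_threshold` and the toy labels and scores are assumptions made for this illustration.

```python
import numpy as np

def best_f1_threshold(y_true, scores):
    """Scan each distinct score as a threshold and return (threshold, F1) with the best F1."""
    best_t, best_f1 = None, -1.0
    for t in np.unique(scores):
        pred = scores >= t
        tp = np.sum(pred & (y_true == 1))
        fp = np.sum(pred & (y_true == 0))
        fn = np.sum(~pred & (y_true == 1))
        f1 = 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1

# Hypothetical labels and predicted probabilities
y_toy = np.array([0, 0, 1, 0, 1, 0, 0, 1, 0, 0])
scores = np.array([0.1, 0.3, 0.8, 0.2, 0.6, 0.7, 0.05, 0.4, 0.15, 0.25])

best_t, best_f1 = best_f1_threshold(y_toy, scores)
print(f"best threshold: {best_t:.2f}, F1: {best_f1:.3f}")
```

In practice the threshold would need to be selected on held-out data (for example, within each training fold) and then applied to the validation fold, otherwise the reported F1 is optimistically biased.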