In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import seaborn as sns
from foundry.transforms import Dataset

### Plotting Functions

In [None]:
# Precision-Recall Curve
def plot_precision_recall_curve(pr_curve_df, figsize=(10, 6)):
    """
    Plot a precision-recall curve together with its no-skill baseline.

    The area under the curve is integrated with the trapezoidal rule over the
    recall-sorted points and written into a text box on the plot.

    Parameters:
    pr_curve_df - pandas DF with precision, recall and no_skill_precision columns
                  (positional access via .iloc is used, so a PySpark DF must be
                  converted with .toPandas() before calling)
    figsize - Figure size tuple (default: (10, 6))

    Returns:
    The matplotlib.pyplot module (plt.show() has already been called)
    """
    # The baseline is constant, so the first row is representative
    no_skill = pr_curve_df['no_skill_precision'].iloc[0]

    # Plot precision-recall curve
    plt.figure(figsize=figsize)
    plt.plot(pr_curve_df['recall'], pr_curve_df['precision'],
             marker='.', linestyle='-', color='blue', label='Model')

    # Plot no-skill baseline
    plt.plot([0, 1], [no_skill, no_skill],
             linestyle='--', color='red', label='No Skill')

    # Add labels and title
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve', fontsize=14)

    # Add grid and legend
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(fontsize=10)

    # Set axis limits
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # np.trapezoid only exists in NumPy >= 2.0; older releases expose the same
    # routine as np.trapz. Resolve whichever one is available.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Integrate precision over recall; points must be ordered by recall first
    sorted_data = pr_curve_df.sort_values('recall')
    ap = integrate(sorted_data['precision'], sorted_data['recall'])
    plt.text(0.05, 0.05, f'Average Precision: {ap:.4f}',
             transform=plt.gca().transAxes, fontsize=10, bbox=dict(facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

    return plt

In [None]:
# Confusion Matrix
def plot_confusion_matrix(confusion_df, normalize=False, title='Confusion Matrix', cmap='Blues', figsize=(10, 8)):
    """
    Plot a confusion matrix heatmap from long-format confusion data.

    For a 2x2 matrix, accuracy, precision, recall and F1 are also written next
    to the plot. These metrics are always derived from the raw counts, even
    when the heatmap itself is row-normalized.

    Parameters:
    confusion_df - pandas DF with columns: actual, predicted, count
    normalize - Whether to normalize the displayed matrix by row (default: False)
    title - Plot title (default: 'Confusion Matrix')
    cmap - Colormap to use (default: 'Blues')
    figsize - Figure size (default: (10, 8))

    Returns:
    The matplotlib.pyplot module; call .show() on it to display the figure
    """
    # Every label seen as either actual or predicted gets a row and a column
    classes = sorted(set(confusion_df['actual'].unique()).union(set(confusion_df['predicted'].unique())))
    position = {label: idx for idx, label in enumerate(classes)}

    # Square count matrix; (actual, predicted) pairs absent from the data stay 0.
    # If a pair occurs more than once, the first occurrence is the one used.
    counts = np.zeros((len(classes), len(classes)))
    first_rows = confusion_df.drop_duplicates(subset=['actual', 'predicted'], keep='first')
    for actual, predicted, count in zip(first_rows['actual'], first_rows['predicted'], first_rows['count']):
        row, col = position.get(actual), position.get(predicted)
        if row is not None and col is not None:
            counts[row, col] = count

    # Build the matrix that is displayed; `counts` is kept untouched for metrics
    if normalize:
        row_totals = counts.sum(axis=1, keepdims=True)
        # Rows with no samples (label only ever predicted) stay 0 instead of 0/0 = NaN
        shown = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)
        fmt = '.2f'
    elif np.all(np.mod(counts, 1) == 0):
        # Whole-number counts are displayed as integers
        shown = counts.astype(int)
        fmt = 'd'
    else:
        shown = counts
        fmt = '.1f'

    # Create fig
    plt.figure(figsize=figsize)

    # Create heatmap
    sns.heatmap(shown, annot=True, fmt=fmt, cmap=cmap, cbar=False,
                xticklabels=classes, yticklabels=classes)

    # Set labels
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(title)

    # Add metrics to the plot (binary case only; the second sorted class is
    # treated as the positive class)
    if counts.shape == (2, 2):
        tn, fp = counts[0, 0], counts[0, 1]
        fn, tp = counts[1, 0], counts[1, 1]
        total = tp + tn + fp + fn

        accuracy = (tp + tn) / total if total > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        metrics_text = (f"Accuracy: {accuracy:.3f}\n"
                        f"Precision: {precision:.3f}\n"
                        f"Recall: {recall:.3f}\n"
                        f"F1 Score: {f1:.3f}")

        # x > 1 in figure coordinates places the text to the right of the figure
        plt.figtext(1.05, 0.5, metrics_text, ha='left', fontsize=12)

    plt.tight_layout()

    return plt

In [None]:
# ROC Curve
def plot_roc_curve(roc_df, title='ROC Curve', figsize=(10, 8)):
    """
    Plot a ROC curve with the random-guess diagonal and optional threshold markers.

    Parameters:
    roc_df - pandas DF with columns: fpr, tpr and optionally threshold, auc
             (auc is read from the first row only)
    title - Plot title (default: 'ROC Curve')
    figsize - Figure size (default: (10, 8))

    Returns:
    The matplotlib.pyplot module; call .show() on it to display the figure
    """
    # A missing column, an empty frame or a NaN value all mean "AUC unknown".
    # An explicit None check is needed so a legitimate AUC of 0.0 is still shown.
    auc = None
    if 'auc' in roc_df.columns and len(roc_df) > 0:
        first_auc = roc_df['auc'].iloc[0]
        if pd.notna(first_auc):
            auc = first_auc
    curve_label = f'ROC Curve (AUC = {auc:.3f})' if auc is not None else 'ROC Curve'

    # Plot ROC curve
    plt.figure(figsize=figsize)
    plt.plot(roc_df['fpr'], roc_df['tpr'],
             lw=2, color='blue',
             label=curve_label)

    # Plot the random guess line
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray',
             label='Random Guess (AUC = 0.5)')

    # Add labels and title
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)

    # Add grid and legend
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend(loc='lower right')

    # Set axis limits
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])

    # Add annotation for perfect point
    plt.annotate('Perfect Classification', xy=(0, 1), xytext=(0.2, 0.8),
                 arrowprops=dict(facecolor='green', shrink=0.05))

    # Add threshold markers (optional)
    if 'threshold' in roc_df.columns:
        # Select a few threshold points to display
        thresholds_to_show = [0.1, 0.3, 0.5, 0.7, 0.9]

        # Drop rows without a threshold. The index is reset so that idxmin()
        # below always identifies exactly one row, even if roc_df arrived with
        # duplicate index labels.
        valid_threshold_df = roc_df.dropna(subset=['threshold']).reset_index(drop=True)

        # Only proceed if we have valid thresholds
        if not valid_threshold_df.empty:
            for threshold in thresholds_to_show:
                # Row whose threshold is nearest to the target value
                abs_diffs = (valid_threshold_df['threshold'] - threshold).abs()
                point = valid_threshold_df.loc[abs_diffs.idxmin()]

                # Plot the point and label it with the target threshold
                plt.plot(point['fpr'], point['tpr'], 'ro', markersize=6)
                plt.annotate(f"t={threshold:.1f}",
                             (point['fpr'], point['tpr']),
                             textcoords="offset points",
                             xytext=(0, 10),
                             ha='center')

    plt.tight_layout()

    return plt

In [None]:
# Lift chart
def plot_lift_chart(lift_df, title="Lift Chart and Cumulative Gain", figsize=(14, 8), save_path=None):
    """
    Create a cumulative gains chart and a lift chart side by side.

    Parameters:
    lift_df - DF with columns: percentile, model_cumulative_gain, random_gain,
              lift and optionally perfect_gain. Anything that is not a pandas
              DataFrame is converted with .toPandas().
              Axes are formatted as percentages and reference lines sit at
              50 and 80, so values are presumably on a 0-100 scale.
    title - Plot title (default: "Lift Chart and Cumulative Gain")
    figsize - Figure size (default: (14, 8))
    save_path - Path to save the figure; nothing is written when None (default: None)

    Returns:
    matplotlib figure
    """
    # Convert to pandas if needed
    if not isinstance(lift_df, pd.DataFrame):
        lift_pandas = lift_df.toPandas()
    else:
        lift_pandas = lift_df

    # Create fig with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Cumulative Gain Chart
    ax1.plot(lift_pandas['percentile'], lift_pandas['model_cumulative_gain'],
             'b-', linewidth=2, label='Model')
    ax1.plot(lift_pandas['percentile'], lift_pandas['random_gain'],
             'r--', linewidth=2, label='Random')

    if 'perfect_gain' in lift_pandas.columns:
        ax1.plot(lift_pandas['percentile'], lift_pandas['perfect_gain'],
                 'g-.', linewidth=2, label='Perfect')

    # Format axes
    ax1.set_xlabel('Percentage of Sample', fontsize=12)
    ax1.set_ylabel('Percentage of Positives Captured', fontsize=12)
    ax1.set_title('Cumulative Gains Chart', fontsize=14)
    ax1.grid(True, linestyle='--', alpha=0.7)
    ax1.legend(fontsize=10)

    # Format tick labels as %s
    ax1.xaxis.set_major_formatter(mtick.PercentFormatter())
    ax1.yaxis.set_major_formatter(mtick.PercentFormatter())

    # Add reference lines
    ax1.axhline(y=50, color='gray', linestyle='-', alpha=0.3)
    ax1.axhline(y=80, color='gray', linestyle='-', alpha=0.3)

    # Annotate where the model captures (approximately) 50% and 80% of positives
    if not lift_pandas.empty:
        for target_gain in [50, 80]:
            # Closest available point to the target gain
            closest_idx = (lift_pandas['model_cumulative_gain'] - target_gain).abs().idxmin()
            x_pct = lift_pandas.loc[closest_idx, 'percentile']

            # Add vertical line and annotation
            ax1.axvline(x=x_pct, color='blue', linestyle=':', alpha=0.5)
            ax1.annotate(f'{target_gain}% of positives \ncaptured at {x_pct:.1f}%',
                         xy=(x_pct, target_gain), xytext=(x_pct+5, target_gain-10),
                         arrowprops=dict(arrowstyle='->', color='blue', alpha=0.7))

    # Lift Chart
    ax2.plot(lift_pandas['percentile'], lift_pandas['lift'], 'b-', linewidth=2)
    ax2.axhline(y=1, color='r', linestyle='--', label='No Lift (Random)')

    # Format axes
    ax2.set_xlabel('Percentage of Sample', fontsize=12)
    ax2.set_ylabel('Lift', fontsize=12)
    ax2.set_title('Lift Chart', fontsize=14)
    ax2.grid(True, linestyle='--', alpha=0.7)
    ax2.legend(fontsize=10)

    # Format tick labels
    ax2.xaxis.set_major_formatter(mtick.PercentFormatter())

    # Highlight the point of highest lift
    if not lift_pandas.empty:
        max_lift_idx = lift_pandas['lift'].idxmax()
        max_lift_percentile = lift_pandas.loc[max_lift_idx, 'percentile']
        max_lift_value = lift_pandas.loc[max_lift_idx, 'lift']

        ax2.annotate(f'Max Lift: {max_lift_value:.2f}x at {max_lift_percentile}%',
                     xy=(max_lift_percentile, max_lift_value),
                     xytext=(max_lift_percentile+10, max_lift_value),
                     arrowprops=dict(arrowstyle='->', color='blue', alpha=0.7))

    plt.tight_layout()
    fig.suptitle(title, y=1.05, fontsize=16)

    # Honour save_path. The suptitle sits above the figure area (y=1.05), so a
    # tight bounding box is required to keep it in the saved image.
    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight')

    return fig

In [None]:
def plot_cumulative_importance(df, figsize=(12, 6)):
    """
    Show the top features and the cumulative importance curve, and print how
    many features are needed to reach 80%, 90% and 95% of total importance.

    Features are ranked by importance (descending) inside the function, so the
    input may arrive in any order. The input DF is not modified.

    Parameters:
    df - pandas DF with columns: feature, importance
    figsize - Figure size (default: (12, 6))

    Returns:
    None (the figure is shown and the summary is printed)
    """
    # Rank features; a stable sort keeps the input order among ties, so input
    # that is already sorted is left exactly as it was.
    df_sorted = df.sort_values('importance', ascending=False, kind='stable').reset_index(drop=True)
    df_sorted['cumulative_importance'] = df_sorted['importance'].cumsum()
    df_sorted['cumulative_percentage'] = (df_sorted['cumulative_importance'] /
                                          df_sorted['importance'].sum()) * 100

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Bar chart
    top_20 = df_sorted.head(20)
    ax1.barh(range(len(top_20)), top_20['importance'], color='steelblue')
    ax1.set_yticks(range(len(top_20)))
    ax1.set_yticklabels(top_20['feature'])
    ax1.set_xlabel('Importance')
    ax1.set_title('Top 20 Features')
    ax1.invert_yaxis()

    # Cumulative line plot; x starts at 1 so it reads as "number of features"
    ax2.plot(range(1, len(df_sorted) + 1), df_sorted['cumulative_percentage'],
             color='darkred', linewidth=2)
    ax2.axhline(y=80, color='green', linestyle='--', label='80%')
    ax2.axhline(y=90, color='orange', linestyle='--', label='90%')
    ax2.axhline(y=95, color='red', linestyle='--', label='95%')
    ax2.set_xlabel('Number of Features')
    ax2.set_ylabel('Cumulative Importance (%)')
    ax2.set_title('Cumulative Feature Importance')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print num features needed for thresholds
    cumulative = df_sorted['cumulative_percentage']
    for threshold in [80, 90, 95]:
        # Features still strictly short of the threshold, plus the one that
        # reaches it. isclose() guards against float noise such as
        # 80.00000000000001 or 79.99999999999999 when the share is exactly 80%.
        still_short = (cumulative < threshold) & ~np.isclose(cumulative, threshold)
        n_features = min(int(still_short.sum()) + 1, len(df_sorted))
        print(f"{n_features} features account for {threshold}% of importance")


In [None]:
def plot_feature_importance_horizontal(df, top_n=20, figsize=(10, 8)):
    """
    Draw a horizontal bar chart for the first top_n rows of a feature
    importance table and show it.

    Parameters:
    df - DF with columns: feature, importance; rows are taken in the order
         given (the first top_n rows are plotted, no sorting is done here)
    top_n - Number of leading rows to plot (default: 20)
    figsize - Figure size (default: (10, 8))
    """
    leading_rows = df.head(top_n)
    bar_positions = range(len(leading_rows))
    chart_title = f'Top {top_n} Feature Importances'

    plt.figure(figsize=figsize)

    # One bar per feature, labelled on the y axis with the feature name
    plt.barh(bar_positions, leading_rows['importance'], color='steelblue')
    plt.yticks(bar_positions, leading_rows['feature'])

    # Flip the y axis so the first row is drawn at the top
    plt.gca().invert_yaxis()

    plt.title(chart_title, fontsize=14, fontweight='bold')
    plt.ylabel('Feature', fontsize=12)
    plt.xlabel('Importance', fontsize=12)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_calibration_curve(df, n_bins=10, figsize=(10, 6),
                           target_col='noShow_day1_target',
                           prob_col='probability_positive'):
    """
    Check if predicted probabilities match actual frequencies.
    Perfect calibration = diagonal line.

    Predictions are grouped into n_bins equal-width probability bins; each bin
    is drawn as one point (mean predicted probability vs. observed positive
    rate) sized by the number of predictions in it. The per-bin table is also
    printed. The input DF is not modified.

    Parameters:
    df - pandas DF containing target_col and prob_col
    n_bins - Number of probability bins passed to pd.cut (default: 10)
    figsize - Figure size (default: (10, 6))
    target_col - Column with the actual outcome; its per-bin mean is the
                 observed rate (default: 'noShow_day1_target')
    prob_col - Column with the predicted probability of the positive class
               (default: 'probability_positive')
    """
    # Work on a private copy of just the needed columns so the caller's frame
    # does not gain helper columns as a side effect
    binned = df[[target_col, prob_col]].copy()

    # Create probability bins
    binned['prob_bin'] = pd.cut(binned[prob_col], bins=n_bins)

    # Calculate actual rate in each bin (empty bins are dropped via observed=True)
    calibration = binned.groupby('prob_bin', observed=True).agg({
        target_col: 'mean',
        prob_col: ['mean', 'count']
    }).reset_index()
    calibration.columns = ['prob_bin', 'actual_rate', 'mean_predicted_prob', 'count']

    plt.figure(figsize=figsize)
    plt.scatter(calibration['mean_predicted_prob'], calibration['actual_rate'],
                s=calibration['count']*2, alpha=0.6, color='blue')
    plt.plot([0, 1], [0, 1], 'r--', label='Perfect Calibration')
    plt.xlabel('Mean Predicted Probability', fontsize=12)
    plt.ylabel('Actual No-Show Rate', fontsize=12)
    plt.title('Calibration Curve\n(Point size = number of predictions)',
              fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print("\n=== Calibration Analysis ===")
    print(calibration.to_string(index=False))

In [None]:
def analyze_high_confidence_errors(df, confidence_threshold=0.7):
    """
    Find predictions where model was very confident but wrong.

    A false positive is "high confidence" when probability_positive is at
    least confidence_threshold; a false negative when it is at most
    1 - confidence_threshold. Counts are printed and the probability
    distribution of each group is plotted. The input DF is not modified.

    Parameters:
    df - pandas DF with columns: noShow_day1_target, prediction, probability_positive
    confidence_threshold - Probability cut-off for "confident" (default: 0.7)

    Returns:
    (high_conf_fp, high_conf_fn) - two DFs holding the matching rows of df,
    each with an added is_error column
    """
    # Flag errors on a copy so the caller's frame does not gain a column
    scored = df.assign(is_error=(df['noShow_day1_target'] != df['prediction']).astype(int))
    lower_cutoff = 1 - confidence_threshold

    # High confidence false positives - confidently predicted no-show, but showed
    high_conf_fp = scored[
        (scored['is_error'] == 1) &
        (scored['noShow_day1_target'] == 0) &
        (scored['probability_positive'] >= confidence_threshold)
    ].copy()

    # High confidence false negatives - confidently predicted show, but no-showed
    high_conf_fn = scored[
        (scored['is_error'] == 1) &
        (scored['noShow_day1_target'] == 1) &
        (scored['probability_positive'] <= lower_cutoff)
    ].copy()

    # Share of all predictions; an empty input reports 0% instead of dividing by zero
    total = len(scored)
    fp_share = len(high_conf_fp) / total * 100 if total else 0.0
    fn_share = len(high_conf_fn) / total * 100 if total else 0.0

    print(f"\n=== High Confidence Errors (threshold={confidence_threshold}) ===")
    print(f"High-confidence False Positives: {len(high_conf_fp)} "
          f"({fp_share:.2f}% of all predictions)")
    print(f"High-confidence False Negatives: {len(high_conf_fn)} "
          f"({fn_share:.2f}% of all predictions)")

    # Plot distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.hist(high_conf_fp['probability_positive'], bins=30, color='orange',
             alpha=0.7, edgecolor='black')
    ax1.set_title(f'High-Confidence False Positives (n={len(high_conf_fp)})')
    ax1.set_xlabel('Predicted Probability of No-Show')
    ax1.set_ylabel('Count')
    ax1.axvline(confidence_threshold, color='red', linestyle='--',
                label=f'Confidence Threshold: {confidence_threshold}')
    ax1.legend()

    ax2.hist(high_conf_fn['probability_positive'], bins=30, color='red',
             alpha=0.7, edgecolor='black')
    ax2.set_title(f'High-Confidence False Negatives (n={len(high_conf_fn)})')
    ax2.set_xlabel('Predicted Probability of No-Show')
    ax2.set_ylabel('Count')
    # 'g' formatting avoids float noise such as 0.30000000000000004 for 1 - 0.7
    ax2.axvline(lower_cutoff, color='red', linestyle='--',
                label=f'Confidence Threshold: {lower_cutoff:g}')
    ax2.legend()

    plt.tight_layout()
    plt.show()

    return high_conf_fp, high_conf_fn

### Linear Regression Model Plots

In [None]:
lr_pr_curve_data = Dataset.get("lr_pr_curve_data").read_table(format="pandas")
lr_conf_matrix_plot_data = Dataset.get("lr_conf_matrix_plot_data").read_table(format="pandas")
lr_roc_curve_data = Dataset.get("lr_roc_curve_data").read_table(format="pandas")

In [None]:
# LR curve plot
plot_precision_recall_curve(
    lr_pr_curve_data, 
    figsize=(12, 8)
);

In [None]:
plot = plot_confusion_matrix(lr_conf_matrix_plot_data, title="LR Model Confusion Matrix")
plot.show();

In [None]:
plot = plot_roc_curve(lr_roc_curve_data, title="LR Model PROC Curve")
plot.show();

### Multilayer Perceptron Model Eval Plots

In [None]:
mlp_pr_curve_data = Dataset.get("mlp_pr_curve_data").read_table(format="pandas")
mlp_conf_matrix_plot_data = Dataset.get("mlp_conf_matrix_plot_data").read_table(format="pandas")
mlp_roc_curve_data = Dataset.get("mlp_roc_curve_data").read_table(format="pandas")

In [None]:
plot_precision_recall_curve(
    mlp_pr_curve_data, 
    figsize=(12, 8)
);

In [None]:
plot = plot_confusion_matrix(mlp_conf_matrix_plot_data, title="MLP Model Confusion Matrix")
plot.show();

In [None]:
plot = plot_roc_curve(mlp_roc_curve_data, title="MLP Model PROC Curve")
plot.show();

### Random Forest Model Eval Charts

In [None]:
rf_pr_curve_data = Dataset.get("rf1_pr_curve_data").read_table(format="pandas")
rf_conf_matrix_plot_data = Dataset.get("rf1_conf_matrix_plot_data").read_table(format="pandas")
rf_roc_curve_data = Dataset.get("rf1_roc_curve_data").read_table(format="pandas")

In [None]:
plot_precision_recall_curve(
    rf_pr_curve_data, 
    figsize=(12, 8)
);

In [None]:
plot = plot_confusion_matrix(rf_conf_matrix_plot_data, title="RF Model Confusion Matrix")
plot.show();

In [None]:
plot = plot_roc_curve(rf_roc_curve_data, title="RF Model PROC Curve")
plot.show();

### Gradient-Boosted Tree Model Eval Charts

In [None]:
gbt_pr_curve_data = Dataset.get("gbt1_pr_curve_data").read_table(format="pandas")
gbt_conf_matrix_plot_data = Dataset.get("gbt1_conf_matrix_plot_data").read_table(format="pandas")
gbt_roc_curve_data = Dataset.get("gbt1_roc_curve_data").read_table(format="pandas")

In [None]:
plot_precision_recall_curve(
    gbt_pr_curve_data, 
    figsize=(12, 8)
);

In [None]:
plot = plot_confusion_matrix(gbt_conf_matrix_plot_data, title="GBT Model Confusion Matrix")
plot.show();

In [None]:
plot = plot_roc_curve(gbt_roc_curve_data, title="GBT Model PROC Curve")
plot.show();

### Support Vector Machine Model Eval Plots

In [None]:
svm_pr_curve_data = Dataset.get("svm_pr_curve_data").read_table(format="pandas")
svm_conf_matrix_plot_data = Dataset.get("svm_conf_matrix_plot_data").read_table(format="pandas")
svm_roc_curve_data = Dataset.get("svm_roc_curve_data").read_table(format="pandas")

In [None]:
plot_precision_recall_curve(
    svm_pr_curve_data, 
    figsize=(12, 8)
);

In [None]:
plot = plot_confusion_matrix(svm_conf_matrix_plot_data, title="SVM Model Confusion Matrix")
plot.show();

In [None]:
plot = plot_roc_curve(svm_roc_curve_data, title="SVM Model PROC Curve")
plot.show();

### Random Forest Model Feature Importance & Error Eval Plots

In [None]:
rf_feature_importance = Dataset.get("rf_feature_importance").read_table(format="pandas")
rf_predictions = Dataset.get("rf_predictions").read_table(format="pandas")
rf1_lift_chart_data = Dataset.get("rf1_lift_chart_data").read_table(format="pandas")

In [None]:
plot_cumulative_importance(rf_feature_importance)

In [None]:
plot_calibration_curve(rf_predictions)

In [None]:
high_conf_fp, high_conf_fn = analyze_high_confidence_errors(rf_predictions, confidence_threshold=0.7)

In [None]:
lift_chart_fig = plot_lift_chart(
    lift_df=rf1_lift_chart_data,
    title="RF Model Performance: Lift and Cumulative Gain",
    figsize=(14, 7),
)

# Display the plot
lift_chart_fig.show();