# Customer Churn Prediction - Model Interpretability

This notebook implements model interpretability techniques to understand the factors driving customer churn predictions.

In [None]:
# Import necessary libraries
import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from IPython.display import display
from sklearn.inspection import permutation_importance
from sklearn.metrics import confusion_matrix

# Set up plotting
%matplotlib inline
plt.style.use('seaborn-whitegrid')
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Create directories for saving outputs
os.makedirs('../docs/plots', exist_ok=True)

## 1. Load Data and Models

In [None]:
# Read the feature-engineered churn dataset written by the preprocessing step
engineered_csv = '../data/processed/churn_engineered.csv'
df_engineered = pd.read_csv(engineered_csv)

# Quick look at what was loaded
print(f"Dataset shape: {df_engineered.shape}")
print("\nSample data:")
df_engineered.head()

In [None]:
def _read_json(path):
    """Open ``path`` and return its parsed JSON content."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Feature-set definitions (name -> list of columns) and metadata of the selected best model
feature_sets = _read_json('../models/feature_sets.json')
best_model_info = _read_json('../models/best_model_info.json')

for label, key in (('Best model', 'model_name'), ('Feature set', 'feature_set'), ('Metrics', 'metrics')):
    print(f"{label}: {best_model_info[key]}")

In [None]:
# The pickle filename is the model name lower-cased with spaces replaced by underscores,
# e.g. "Logistic Regression" -> logistic_regression.pkl
model_slug = best_model_info['model_name'].lower().replace(' ', '_')
model_path = f"../models/{model_slug}.pkl"

# NOTE(review): pickle.load can execute arbitrary code embedded in the file; only load
# model artifacts produced by this project.
with open(model_path, 'rb') as model_file:
    best_model = pickle.load(model_file)

print(f"Loaded model from {model_path}")

## 2. Prepare Data for Interpretability

In [None]:
# Function to prepare data for a specific feature set
def prepare_data(df, feature_set_name, test_size=0.2, random_state=42):
    from sklearn.model_selection import train_test_split
    
    # Get features for the specified feature set
    features = feature_sets[feature_set_name]
    
    # Prepare features and target
    X = df[features]
    y = df['Exited']
    
    # Split data into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)
    
    return X_train, X_test, y_train, y_test

In [None]:
# Recreate the train/test split for the feature set the best model uses
feature_set_name = best_model_info['feature_set']
X_train, X_test, y_train, y_test = prepare_data(df_engineered, feature_set_name)

# Column order of the model inputs; reused below to label importances
feature_names = list(X_train.columns)
print(f"Number of features: {len(feature_names)}")
print(f"Features: {feature_names}")

## 3. Coefficient Analysis (for Logistic Regression)

In [None]:
# Coefficients are only interpretable for the linear model; other models skip this section.
if best_model_info['model_name'] == 'Logistic Regression':
    # For a binary sklearn LogisticRegression, coef_ has one row of per-feature
    # log-odds weights for the positive class (Exited == 1).
    coefficients = best_model.coef_[0]
    intercept = best_model.intercept_[0]

    # Rank features by coefficient magnitude (sign shows direction of effect)
    coef_df = (
        pd.DataFrame({'Feature': feature_names, 'Coefficient': coefficients})
        .assign(Abs_Coefficient=lambda frame: frame['Coefficient'].abs())
        .sort_values('Abs_Coefficient', ascending=False)
        .reset_index(drop=True)
    )

    print(f"Intercept: {intercept:.4f}")
    print("\nTop coefficients:")
    # A bare expression inside an if-block is not rendered by Jupyter (only a cell's
    # last top-level expression is), so display the table explicitly.
    display(coef_df[['Feature', 'Coefficient']].head(20))
else:
    print("The best model is not Logistic Regression. Skipping coefficient analysis.")

In [None]:
# Bar chart of the 15 largest-magnitude coefficients: positive weights (raise the
# predicted odds of Exited == 1) in red, non-positive in green.
if best_model_info['model_name'] == 'Logistic Regression':
    plt.figure(figsize=(12, 10))

    bars = coef_df.head(15).copy()
    bars['Color'] = np.where(bars['Coefficient'] > 0, 'red', 'green')
    # Order by signed value so positive and negative weights separate visually
    bars = bars.sort_values('Coefficient')

    plt.barh(bars['Feature'], bars['Coefficient'], color=bars['Color'])
    plt.axvline(x=0, color='black', linestyle='--')
    plt.xlabel('Coefficient Value', fontsize=12)
    plt.ylabel('Feature', fontsize=12)
    plt.title('Top 15 Logistic Regression Coefficients', fontsize=15)
    plt.grid(True, axis='x')
    plt.tight_layout()
    plt.savefig('../docs/plots/logistic_regression_coefficients.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Persist the full (sorted) coefficient table
    coef_csv = '../docs/logistic_regression_coefficients.csv'
    coef_df[['Feature', 'Coefficient']].to_csv(coef_csv, index=False)
    print(f"Coefficients saved to {coef_csv}")

## 4. Permutation Importance

In [None]:
# Permutation importance on the held-out split: drop in the model's default score when
# each feature is shuffled, averaged over 10 shuffles.
perm_importance = permutation_importance(
    best_model,
    X_test,
    y_test,
    n_repeats=10,
    random_state=42,
)

perm_importance_df = (
    pd.DataFrame({
        'Feature': feature_names,
        'Importance': perm_importance.importances_mean,
        'Std': perm_importance.importances_std,
    })
    .sort_values('Importance', ascending=False)
    .reset_index(drop=True)
)

print("Permutation Importance:")
perm_importance_df.head(15)

In [None]:
# Horizontal bars (with ±1 std error bars) for the 15 most important features,
# smallest at the bottom so the most important one ends up on top.
top_perm = perm_importance_df.head(15).sort_values('Importance')

fig, ax = plt.subplots(figsize=(12, 10))
ax.barh(top_perm['Feature'], top_perm['Importance'],
        xerr=top_perm['Std'], capsize=5, color='skyblue')
ax.set_xlabel('Permutation Importance', fontsize=12)
ax.set_ylabel('Feature', fontsize=12)
ax.set_title('Top 15 Features by Permutation Importance', fontsize=15)
ax.grid(True, axis='x')
fig.tight_layout()
fig.savefig('../docs/plots/permutation_importance.png', dpi=300, bbox_inches='tight')
plt.show()

# Persist the full ranking
perm_csv = '../docs/permutation_importance.csv'
perm_importance_df.to_csv(perm_csv, index=False)
print(f"Permutation importance saved to {perm_csv}")

## 5. SHAP Analysis

In [None]:
# Build a SHAP explainer suited to the model type
if best_model_info['model_name'] == 'Logistic Regression':
    # Linear models: LinearExplainer with the training split as background data
    explainer = shap.LinearExplainer(best_model, X_train)
else:
    # Model-agnostic fallback; a 100-row background sample keeps KernelExplainer tractable
    explainer = shap.KernelExplainer(best_model.predict_proba, shap.sample(X_train, 100))

# Explain a fixed random sample of (at most) 100 test rows; DataFrame.sample raises if
# asked for more rows than exist.
X_test_sample = X_test.sample(min(100, len(X_test)), random_state=42)
shap_values = explainer.shap_values(X_test_sample)

# Explaining predict_proba gives one explanation per class. Depending on the shap version
# this is a list of per-class arrays (older releases) or a single array shaped
# (n_samples, n_features, n_classes) (newer releases). Keep class 1 (churn) so the plots
# below receive a 2-D (n_samples, n_features) array.
if isinstance(shap_values, list):
    shap_values = shap_values[1]
elif getattr(shap_values, 'ndim', 2) == 3:
    shap_values = shap_values[..., 1]

In [None]:
# Global view: bar-style SHAP summary (per-feature magnitude over the explained sample)
shap_bar_path = '../docs/plots/shap_feature_importance.png'

plt.figure(figsize=(12, 10))
shap.summary_plot(shap_values, X_test_sample, plot_type="bar", show=False)
plt.title('SHAP Feature Importance', fontsize=15)
plt.tight_layout()
plt.savefig(shap_bar_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Per-sample view: default SHAP summary plot, points coloured by the feature's value
shap_summary_path = '../docs/plots/shap_summary_plot.png'

plt.figure(figsize=(12, 10))
shap.summary_plot(shap_values, X_test_sample, show=False)
plt.title('SHAP Summary Plot', fontsize=15)
plt.tight_layout()
plt.savefig(shap_summary_path, dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# SHAP dependence plots for the five features ranked highest by permutation importance
top_features = perm_importance_df['Feature'].head(5).tolist()

for feature in top_features:
    # Without an `ax`, shap.dependence_plot opens its own new figure, which left the
    # pre-created figure blank. Draw into an explicit axis instead.
    fig, ax = plt.subplots(figsize=(10, 6))
    shap.dependence_plot(feature, shap_values, X_test_sample, ax=ax, show=False)
    ax.set_title(f'SHAP Dependence Plot for {feature}', fontsize=15)
    fig.tight_layout()
    # Feature names go into a file name; replace anything that is not path-safe
    safe_name = ''.join(ch if ch.isalnum() or ch in '_-.' else '_' for ch in feature)
    fig.savefig(f'../docs/plots/shap_dependence_{safe_name}.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

## 6. Customer Segment Analysis

In [None]:
# Analyze churn rates by different customer segments

def churn_rate_by(df, by, **groupby_kwargs):
    """Return per-group row count and churn rate (in %) of ``df['Exited']``.

    ``by`` is anything accepted by ``DataFrame.groupby`` (column name(s) or an aligned
    Series); extra keyword arguments are forwarded to ``groupby``. The result has
    columns ``['Count', 'Churn Rate (%)']``.
    """
    summary = df.groupby(by, **groupby_kwargs)['Exited'].agg(['count', 'mean'])
    summary['mean'] = summary['mean'] * 100  # fraction -> percentage
    summary.columns = ['Count', 'Churn Rate (%)']
    return summary


# Age groups. pd.cut bins are right-inclusive, so e.g. '<30' covers ages up to and
# including 30. Built as a standalone Series so df_engineered (used for scoring later)
# is not modified.
age_groups = pd.cut(
    df_engineered['Age'],
    bins=[0, 30, 40, 50, 60, 100],
    labels=['<30', '30-40', '40-50', '50-60', '>60'],
).rename('AgeGroup')
# observed=False keeps every age bin, including empty ones; it is the current default,
# made explicit because pandas warns that the default will change.
age_group_churn = churn_rate_by(df_engineered, age_groups, observed=False)

print("Churn rate by age group:")
print(age_group_churn)

# Geography (one-hot encoded columns)
geography_churn = churn_rate_by(df_engineered, ['Geography_France', 'Geography_Germany', 'Geography_Spain'])

# Create a more readable index
geography_mapping = {
    (1, 0, 0): 'France',
    (0, 1, 0): 'Germany',
    (0, 0, 1): 'Spain'
}
# DataFrame.rename(index=dict) on a MultiIndex maps individual level values, not whole
# tuples, so a rename would leave the index unchanged. Build a flat index instead
# (True/1 and False/0 hash and compare equal, so bool dummies match the keys too).
geography_churn.index = pd.Index(
    [geography_mapping.get(tuple(key), str(key)) for key in geography_churn.index],
    name='Geography',
)

print("\nChurn rate by geography:")
print(geography_churn)

# Activity status: map by value rather than assigning labels positionally, which
# silently mislabels if a group is absent.
activity_churn = churn_rate_by(df_engineered, 'IsActiveMember').rename(index={0: 'Inactive', 1: 'Active'})

print("\nChurn rate by activity status:")
print(activity_churn)

# Number of products
product_churn = churn_rate_by(df_engineered, 'NumOfProducts')

print("\nChurn rate by number of products:")
print(product_churn)

In [None]:
# Visualize churn rates by customer segments: one bar panel per segment table
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

segment_panels = [
    (age_group_churn, 'skyblue', 'Churn Rate by Age Group', 'Age Group'),
    (geography_churn, 'lightgreen', 'Churn Rate by Geography', 'Geography'),
    (activity_churn, 'salmon', 'Churn Rate by Activity Status', 'Activity Status'),
    (product_churn, 'mediumpurple', 'Churn Rate by Number of Products', 'Number of Products'),
]

# axes.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1]
for panel_ax, (table, bar_color, panel_title, x_label) in zip(axes.flat, segment_panels):
    rates = table['Churn Rate (%)']
    rates.plot(kind='bar', ax=panel_ax, color=bar_color)
    panel_ax.set_title(panel_title, fontsize=14)
    panel_ax.set_xlabel(x_label, fontsize=12)
    panel_ax.set_ylabel('Churn Rate (%)', fontsize=12)
    panel_ax.grid(axis='y')
    # Label each bar with its rate, one unit above the bar top
    for bar_pos, rate in enumerate(rates):
        panel_ax.text(bar_pos, rate + 1, f"{rate:.1f}%", ha='center', fontsize=10)

plt.tight_layout()
plt.savefig('../docs/plots/churn_rate_by_segments.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. High-Risk Customer Profiles

In [None]:
# Score every customer in the dataset.
# NOTE(review): this presumably includes the rows the model was trained on, so those
# probabilities are in-sample and may be optimistic — confirm against training.
X_full = df_engineered[feature_names]
y_pred_proba = best_model.predict_proba(X_full)[:, 1]

# Identify high-risk customers (top 10% by predicted churn probability)
high_risk_threshold = np.percentile(y_pred_proba, 90)
high_risk_customers = df_engineered[y_pred_proba >= high_risk_threshold]

# describe() summarizes the numeric columns only
high_risk_profile = high_risk_customers.describe()
overall_profile = df_engineered.describe()

# One row per feature and one column per group, so the bar chart draws grouped bars
# whose legend matches the groups. (Concatenating on axis=1 and taking .loc['mean']
# produced a single MultiIndexed Series, i.e. one flat run of bars with a wrong legend.)
profile_comparison = pd.DataFrame({
    'High Risk': high_risk_profile.loc['mean'],
    'Overall': overall_profile.loc['mean'],
})
# NOTE(review): raw means of differently scaled columns share one axis here; consider
# plotting the ratio High Risk / Overall if large-valued columns dominate.

fig, ax = plt.subplots(figsize=(12, 8))
profile_comparison.plot(kind='bar', ax=ax)  # legend is taken from the column names
ax.set_title('High-Risk vs Overall Customer Profile')
ax.set_xlabel('Features')
ax.set_ylabel('Mean Value')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
fig.tight_layout()
fig.savefig('../docs/plots/high_risk_profile.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Business Insights and Recommendations

In [None]:
# Compile key churn factors (top 5 by permutation importance)
key_factors = perm_importance_df['Feature'].head(5).tolist()


def _dominant_geography(df):
    """Most common country in ``df``.

    Uses a raw 'Geography' column when present, otherwise the one-hot 'Geography_*'
    columns (the segment analysis above groups by those, so the raw column may not
    exist). Returns 'Unknown' if neither is available.
    """
    if 'Geography' in df.columns:
        return str(df['Geography'].mode().values[0])
    dummy_cols = [col for col in df.columns if col.startswith('Geography_')]
    if not dummy_cols:
        return 'Unknown'
    # The dummy column with the most 1s is the modal country
    return str(df[dummy_cols].sum().idxmax())[len('Geography_'):]


# Identify high-risk segments (averages over the top-10% risk group)
high_risk_segments = {
    'Age': float(high_risk_customers['Age'].mean()),
    'Geography': _dominant_geography(high_risk_customers),
    'IsActiveMember': float(high_risk_customers['IsActiveMember'].mean()),
    'NumOfProducts': float(high_risk_customers['NumOfProducts'].mean())
}

# Provide business recommendations
recommendations = [
    "Focus retention efforts on customers with profiles similar to the high-risk segment",
    f"Pay special attention to customers in {high_risk_segments['Geography']}",
    "Develop targeted retention strategies for inactive members",
    "Consider offering product bundles to increase the number of products per customer",
    "Implement a proactive outreach program for customers as they approach the average high-risk age"
]

# Save insights to JSON
insights = {
    'key_churn_factors': key_factors,
    'high_risk_segments': high_risk_segments,
    'recommendations': recommendations
}

with open('../docs/business_insights.json', 'w') as f:
    json.dump(insights, f, indent=4)

print("Business insights and recommendations saved to ../docs/business_insights.json")

## 9. Individual Customer Prediction Analysis

In [None]:
def analyze_customer(customer_data, model=None, background=None, top_n=3):
    """Explain the churn prediction for a single customer.

    Parameters
    ----------
    customer_data : pandas.Series, dict, or 1-D array-like
        One customer's feature values. A Series/dict is aligned by feature name;
        a plain array must follow the column order of ``background``.
    model : fitted classifier, optional
        Model with ``predict_proba``; defaults to the notebook's ``best_model``.
    background : pandas.DataFrame, optional
        Background data for the SHAP explainer and source of the feature names;
        defaults to ``X_train``.
    top_n : int, default 3
        Number of features (by |SHAP value|) reported as risk factors.

    Returns
    -------
    dict
        ``churn_probability`` (float), ``risk_factors`` (list of records with
        ``feature``, ``importance`` = |SHAP value| and ``shap_value`` = signed SHAP
        value) and ``recommendations`` (list of str).
    """
    model = best_model if model is None else model
    background = X_train if background is None else background
    columns = list(background.columns)

    # Build a one-row DataFrame with the training column names. pandas Series (what
    # X_test.iloc[i] returns) has no .reshape, so the previous reshape call crashed.
    if isinstance(customer_data, pd.Series):
        row = pd.DataFrame([customer_data.reindex(columns)]).infer_objects()
    elif isinstance(customer_data, dict):
        row = pd.DataFrame([customer_data], columns=columns)
    else:
        row = pd.DataFrame(np.asarray(customer_data).reshape(1, -1), columns=columns)

    prediction = float(model.predict_proba(row)[0, 1])

    # NOTE(review): building the explainer on every call is costly; hoist it out if
    # many customers are analysed.
    explainer = shap.Explainer(model, background)
    contributions = np.asarray(explainer(row).values[0])
    # Some explainers return a trailing class axis for classifiers; keep class 1 (churn).
    if contributions.ndim == 2:
        contributions = contributions[:, 1]

    # Rank features by the magnitude of their contribution to this prediction
    risk_factors = (
        pd.DataFrame({
            'feature': columns,
            'importance': np.abs(contributions),
            'shap_value': contributions,
        })
        .sort_values('importance', ascending=False)
        .head(top_n)
    )

    # Recommend actions only for factors that push this customer's churn risk up;
    # the previous |SHAP| > 0 test also fired for factors that lowered risk.
    recommendations = []
    for factor in risk_factors.itertuples(index=False):
        if factor.shap_value <= 0:
            continue
        if 'Age' in factor.feature:
            recommendations.append("Consider age-specific retention offers")
        elif 'IsActiveMember' in factor.feature:
            recommendations.append("Encourage more active engagement with our services")
        elif 'NumOfProducts' in factor.feature:
            recommendations.append("Offer additional products that complement current usage")

    return {
        'churn_probability': prediction,
        'risk_factors': risk_factors.to_dict('records'),
        'recommendations': recommendations
    }

# Pick the test customers with the highest and lowest predicted churn probability.
# y_pred_proba is positional over df_engineered; indexing it with X_test.index relies on
# df_engineered's default RangeIndex (read_csv above is called without index_col).
test_set_proba = y_pred_proba[X_test.index]
high_risk_sample = X_test.iloc[test_set_proba.argmax()]
low_risk_sample = X_test.iloc[test_set_proba.argmin()]

for heading, customer in (("High-risk customer analysis:", high_risk_sample),
                          ("\nLow-risk customer analysis:", low_risk_sample)):
    print(heading)
    print(json.dumps(analyze_customer(customer), indent=2))