
# Customer Churn Prediction (Binary Classification)

This notebook solves a customer churn prediction task using the provided **Customer-Churn.csv** dataset.

We will:

1. Load and explore the dataset (EDA)
2. Clean and preprocess the data
3. Build baseline and ML models (Logistic Regression, Random Forest)
4. Perform minimal hyperparameter tuning
5. Evaluate performance (accuracy, precision, recall, F1, confusion matrix)
6. Interpret important features and discuss limitations


## 1. Setup & Imports

In [None]:

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)


## 2. Load & Inspect the Dataset

In [None]:

# Adjust the path if needed (e.g. when running on Kaggle/Colab)
data_path = "Customer-Churn.csv"
df = pd.read_csv(data_path)

print("Shape of dataset:", df.shape)
df.head()


In [None]:

print("Data types:")
print(df.dtypes)

print("\nMissing values per column:")
print(df.isna().sum())


### 2.1 Churn Distribution & Class Imbalance

In [None]:

churn_counts = df['Churn'].value_counts()
churn_ratio = df['Churn'].value_counts(normalize=True) * 100

print("Churn counts:\n", churn_counts)
print("\nChurn percentage (%):\n", churn_ratio.round(2))

fig, ax = plt.subplots()
ax.bar(churn_counts.index, churn_counts.values)
ax.set_title("Churn Distribution")
ax.set_xlabel("Churn")
ax.set_ylabel("Count")
for i, v in enumerate(churn_counts.values):
    ax.text(i, v + 50, str(v), ha='center')
plt.show()


### 2.2 Quick Look at Feature Groups

In [None]:

print("Columns in the dataset:")
print(df.columns.tolist())

demographic_cols = ['gender', 'SeniorCitizen', 'Partner', 'Dependents']
service_cols = ['PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity',
                'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV',
                'StreamingMovies']
contract_cols = ['Contract', 'PaperlessBilling', 'PaymentMethod']
billing_cols = ['tenure', 'MonthlyCharges', 'TotalCharges']

print("\nDemographics:", demographic_cols)
print("Services:", service_cols)
print("Contract & Payment:", contract_cols)
print("Billing:", billing_cols)


### 2.3 Visual Relationships

In [None]:

# Tenure vs Churn
fig, ax = plt.subplots()
df.boxplot(column='tenure', by='Churn', ax=ax)
ax.set_title("Tenure vs Churn")
ax.set_ylabel("Tenure (months)")
plt.suptitle("")
plt.show()

# MonthlyCharges vs Churn
fig, ax = plt.subplots()
df.boxplot(column='MonthlyCharges', by='Churn', ax=ax)
ax.set_title("Monthly Charges vs Churn")
ax.set_ylabel("Monthly Charges")
plt.suptitle("")
plt.show()

# Contract type vs Churn
contract_churn = pd.crosstab(df['Contract'], df['Churn'])
contract_churn_norm = contract_churn.div(contract_churn.sum(axis=1), axis=0)

print("Contract vs Churn counts:")
print(contract_churn)

fig, ax = plt.subplots()
bottom = np.zeros(len(contract_churn_norm.index))
for col in contract_churn_norm.columns:
    ax.bar(contract_churn_norm.index,
           contract_churn_norm[col].values,
           bottom=bottom,
           label=col)
    bottom += contract_churn_norm[col].values
ax.set_title("Contract Type vs Churn (Proportions)")
ax.set_ylabel("Proportion")
ax.legend(title="Churn")
plt.xticks(rotation=15)
plt.show()


### 2.4 Correlation of Numeric Features

In [None]:

numeric_cols = ['SeniorCitizen', 'tenure', 'MonthlyCharges']

# TotalCharges is still stored as strings (object dtype) here, so it is excluded from this matrix
corr = df[numeric_cols].corr()
print(corr)

fig, ax = plt.subplots()
cax = ax.matshow(corr, vmin=-1, vmax=1)
fig.colorbar(cax)
ticks = np.arange(0, len(numeric_cols), 1)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(numeric_cols, rotation=45, ha='left')
ax.set_yticklabels(numeric_cols)
ax.set_title("Correlation Heatmap (Numeric Features)")
plt.show()


## 3. Data Preprocessing

In [None]:

df_clean = df.copy()

# 3.1 Clean TotalCharges (string-based numeric with possible spaces)
df_clean['TotalCharges'] = pd.to_numeric(df_clean['TotalCharges'], errors='coerce')
print("Number of missing TotalCharges after conversion:", df_clean['TotalCharges'].isna().sum())

# Drop rows where TotalCharges could not be converted (very small fraction)
df_clean = df_clean.dropna(subset=['TotalCharges'])

# 3.2 Simplify 'No phone service' / 'No internet service' categories
service_replace_cols = ['MultipleLines', 'OnlineSecurity', 'OnlineBackup',
                        'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies']

for col in service_replace_cols:
    df_clean[col] = df_clean[col].replace({'No internet service': 'No',
                                           'No phone service': 'No'})

# 3.3 Encode target
df_clean['Churn'] = df_clean['Churn'].map({'No': 0, 'Yes': 1})
y = df_clean['Churn']

# 3.4 Drop identifier
df_clean = df_clean.drop(columns=['customerID'])

# 3.5 Encode binary categorical variables
binary_cols = ['gender', 'Partner', 'Dependents', 'PhoneService', 'PaperlessBilling'] + service_replace_cols

# Map yes/no
yes_no_map = {'Yes': 1, 'No': 0}
for col in binary_cols:
    if col == 'gender':
        df_clean[col] = df_clean[col].map({'Female': 0, 'Male': 1})
    else:
        df_clean[col] = df_clean[col].map(yes_no_map)

# 3.6 One-hot encode multi-category features
multi_cat_cols = ['InternetService', 'Contract', 'PaymentMethod']

X = df_clean.drop(columns=['Churn'])
X = pd.get_dummies(X, columns=multi_cat_cols, drop_first=True)

print("Shape after preprocessing:", X.shape)
print("Remaining dtypes:")
print(X.dtypes.value_counts())


### 3.1 Train / Test Split

In [None]:

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=RANDOM_STATE,
    stratify=y
)

print("Train shape:", X_train.shape, "Test shape:", X_test.shape)
print("Churn proportion in train:")
print(y_train.value_counts(normalize=True).round(3))


## 4. Baseline Model

In [None]:

baseline_clf = DummyClassifier(strategy='most_frequent', random_state=RANDOM_STATE)
baseline_clf.fit(X_train, y_train)
y_pred_baseline = baseline_clf.predict(X_test)

print("Baseline (Most Frequent) Metrics:")
# Built-in round() works whether a metric is returned as a Python float or a NumPy scalar
print("Accuracy :", round(accuracy_score(y_test, y_pred_baseline), 4))
print("Precision:", round(precision_score(y_test, y_pred_baseline, zero_division=0), 4))
print("Recall   :", round(recall_score(y_test, y_pred_baseline, zero_division=0), 4))
print("F1       :", round(f1_score(y_test, y_pred_baseline, zero_division=0), 4))


## 5. Logistic Regression Model

In [None]:

log_reg = LogisticRegression(
    max_iter=1000,
    class_weight='balanced',
    solver='liblinear',
    random_state=RANDOM_STATE
)

log_reg.fit(X_train, y_train)
y_pred_lr = log_reg.predict(X_test)

print("Logistic Regression Metrics:")
# Built-in round() works whether a metric is returned as a Python float or a NumPy scalar
print("Accuracy :", round(accuracy_score(y_test, y_pred_lr), 4))
print("Precision:", round(precision_score(y_test, y_pred_lr), 4))
print("Recall   :", round(recall_score(y_test, y_pred_lr), 4))
print("F1       :", round(f1_score(y_test, y_pred_lr), 4))

print("\nClassification Report:")
print(classification_report(y_test, y_pred_lr, digits=4))


## 6. Random Forest with Minimal Hyperparameter Tuning

In [None]:

rf = RandomForestClassifier(
    random_state=RANDOM_STATE,
    class_weight='balanced',
    n_jobs=-1
)

param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5]
}

grid_search = GridSearchCV(
    rf,
    param_grid,
    cv=3,
    scoring='f1',
    n_jobs=-1,
    verbose=1
)

grid_search.fit(X_train, y_train)

print("Best parameters:", grid_search.best_params_)
print("Best CV F1 score:", grid_search.best_score_.round(4))

rf_best = grid_search.best_estimator_

y_pred_rf = rf_best.predict(X_test)

print("\nRandom Forest (Tuned) Metrics:")
# Built-in round() works whether a metric is returned as a Python float or a NumPy scalar
print("Accuracy :", round(accuracy_score(y_test, y_pred_rf), 4))
print("Precision:", round(precision_score(y_test, y_pred_rf), 4))
print("Recall   :", round(recall_score(y_test, y_pred_rf), 4))
print("F1       :", round(f1_score(y_test, y_pred_rf), 4))

print("\nClassification Report:")
print(classification_report(y_test, y_pred_rf, digits=4))


## 7. Confusion Matrices

In [None]:

# Map each model name to its test-set predictions
predictions = {
    "Baseline (Most Frequent)": y_pred_baseline,
    "Logistic Regression": y_pred_lr,
    "Random Forest (Tuned)": y_pred_rf
}

for name, y_pred in predictions.items():
    print(f"\n{name}")
    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["No churn", "Churn"])
    disp.plot(values_format='d')
    plt.title(f"Confusion Matrix - {name}")
    plt.show()


## 8. Feature Importance (Random Forest)

In [None]:

importances = rf_best.feature_importances_
feature_importance = pd.Series(importances, index=X_train.columns).sort_values(ascending=False)

print("Top 15 features by importance:")
print(feature_importance.head(15))

fig, ax = plt.subplots(figsize=(8, 6))
top_n = 15
ax.barh(feature_importance.head(top_n).index[::-1],
        feature_importance.head(top_n).values[::-1])
ax.set_title("Top Feature Importances (Random Forest)")
ax.set_xlabel("Importance")
plt.tight_layout()
plt.show()



## 9. Interpretation & Discussion

**Which features influence churn?**  

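The notebook computes Random Forest feature importances only; the Logistic Regression coefficients are never printed. Importances rank features but do not give the direction of an effect, so the directions below rely on the plots in Section 2.3 and on patterns that are usually observed in churn data of this kind: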

* **Contract type**: Month-to-month contracts tend to have **higher churn**, while longer-term contracts (1 or 2 year) are associated with more stable customers.
* **Tenure**: Customers with a **shorter tenure** (newer customers) are more likely to churn; long-term customers tend to stay.
* **Monthly and total charges**: Higher **MonthlyCharges** often go with higher churn, especially when customers perceive the service as too expensive for the value received. **TotalCharges** accumulates over time, so its signal overlaps with tenure.
* **Internet-related services** (e.g., Fiber optic, lack of OnlineSecurity/TechSupport) often play a role, as customers who use many services but are unhappy with quality/price may churn.
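
The notebook does not print the Logistic Regression coefficients, so the following is a minimal sketch of how they could be inspected to read off effect directions. It runs on a small synthetic dataset: the `toy` frame, its effect sizes, and the `noise` column are assumptions for illustration, not the churn data. Features are standardized first so that coefficient magnitudes are comparable.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the churn data (assumed effect sizes, illustration only).
rng = np.random.default_rng(42)
n = 2000
toy = pd.DataFrame({
    "tenure": rng.uniform(0, 72, n),
    "MonthlyCharges": rng.uniform(20, 120, n),
    "noise": rng.normal(size=n),
})
logit = -0.08 * (toy["tenure"] - 36) + 0.02 * (toy["MonthlyCharges"] - 70)
toy_y = (rng.uniform(size=n) < 1 / (1 + np.exp(-logit))).astype(int)

# After standardization, a coefficient is the change in log-odds per one standard deviation.
X_scaled = StandardScaler().fit_transform(toy)
model = LogisticRegression(max_iter=1000).fit(X_scaled, toy_y)

coef_table = pd.DataFrame(
    {"coefficient": model.coef_[0], "odds_ratio": np.exp(model.coef_[0])},
    index=toy.columns,
)
# Sort by absolute size; the sign gives the direction of the effect.
order = coef_table["coefficient"].abs().sort_values(ascending=False).index
coef_table = coef_table.reindex(order)
print(coef_table.round(3))
```

In the notebook itself, the same table could be built from `log_reg.coef_[0]` and `X_train.columns`, keeping in mind that `log_reg` was fitted on unscaled features, so its coefficient magnitudes are not directly comparable across columns.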

---

### Limitations & Class Imbalance Effects

* The dataset is moderately imbalanced (churners are the minority).  
  * We mitigated this via **class_weight='balanced'** for Logistic Regression and Random Forest.
* Metrics like **recall and F1** are more informative than raw accuracy in this setting.
* We performed only **minimal hyperparameter tuning**; more extensive tuning (e.g., RandomizedSearchCV, more parameters) could improve performance.
* We used simple one-hot encoding and standard models. More advanced approaches (e.g. gradient boosting like XGBoost/LightGBM, or neural networks) might yield better performance.
* `pd.get_dummies` was applied to the full dataset before the train/test split. One-hot encoding does not use the target, so the leakage risk is small, but the test rows still help determine which dummy columns exist. In production, using pipelines that fit encoders only on the training set is recommended.
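
One way to follow that recommendation is to move encoding and scaling into a scikit-learn `Pipeline`, so that they are fitted on the training split only. The sketch below uses a tiny synthetic frame (the `toy` data, its random labels, and the three-column subset are assumptions for illustration). `handle_unknown="ignore"` is one possible choice for categories that only appear at prediction time.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Tiny synthetic frame standing in for the churn data (illustration only).
rng = np.random.default_rng(42)
n = 400
toy = pd.DataFrame({
    "tenure": rng.integers(1, 73, n),
    "MonthlyCharges": rng.uniform(20, 120, n).round(2),
    "Contract": rng.choice(["Month-to-month", "One year", "Two year"], n),
})
toy_y = (rng.uniform(size=n) < 0.27).astype(int)  # random labels, no real signal

X_tr, X_te, y_tr, y_te = train_test_split(
    toy, toy_y, test_size=0.2, random_state=42, stratify=toy_y
)

# Numeric columns are scaled, the categorical column is one-hot encoded.
preprocess = ColumnTransformer([
    ("num", StandardScaler(), ["tenure", "MonthlyCharges"]),
    ("cat", OneHotEncoder(handle_unknown="ignore"), ["Contract"]),
])

clf = Pipeline([
    ("preprocess", preprocess),
    ("model", LogisticRegression(max_iter=1000, class_weight="balanced")),
])

# fit() learns scaler statistics and encoder categories from the training rows only.
clf.fit(X_tr, y_tr)
test_pred = clf.predict(X_te)
print("Predictions on held-out rows:", test_pred.shape)
```

A pipeline like `clf` can also be passed to `GridSearchCV`, in which case the preprocessing steps are refitted inside every cross-validation fold.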

---

### Possible Improvements

1. Extend cross-validation and hyperparameter tuning to all models; here only the Random Forest is tuned, with 3-fold CV over a small grid.
2. Try more advanced models (Gradient Boosted Trees, XGBoost, CatBoost).
3. Explore feature engineering (e.g., bucketizing tenure, interaction terms).
4. Apply resampling techniques (SMOTE, undersampling) and compare with class weights.
5. Deploy as an API and monitor performance on real-time data.
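
As a sketch of the feature engineering mentioned in item 3, tenure can be bucketized with `pd.cut` and combined with another numeric column. The bin edges, bucket labels, and the `tenure_x_monthly` name below are assumptions chosen for illustration, and the rows are made up rather than taken from the dataset.

```python
import pandas as pd

# Made-up rows for illustration; not taken from Customer-Churn.csv.
toy = pd.DataFrame({
    "tenure": [1, 8, 14, 30, 50, 72],
    "MonthlyCharges": [70.0, 29.9, 95.5, 56.0, 104.8, 20.0],
})

# Assumed bin edges in months. Intervals are right-closed: [0, 12], (12, 24], (24, 48], (48, 72].
bins = [0, 12, 24, 48, 72]
labels = ["0-12", "13-24", "25-48", "49-72"]
toy["tenure_bucket"] = pd.cut(toy["tenure"], bins=bins, labels=labels, include_lowest=True)

# A simple interaction term: the product of two numeric features.
toy["tenure_x_monthly"] = toy["tenure"] * toy["MonthlyCharges"]

# One-hot encode the bucket so that linear models can use it.
toy_encoded = pd.get_dummies(toy, columns=["tenure_bucket"], drop_first=True)
print(toy_encoded.columns.tolist())
```

Whether such features help is an empirical question; they would need to be compared against the current feature set with the same train/test split and metrics.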
