# 02 - Predictions

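Train a logistic-regression churn model, score every customer with a churn probability, and show a ranked list of the customers whose probability is at least 0.50. The ranked list can be exported to CSV.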

In [None]:
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

df = pd.read_csv("../data/churn.csv")

# Keep the identifiers aside so predictions can be labelled per customer later.
customer_ids = df["customer_id"].copy()

# Work on a copy so the raw frame `df` stays untouched.
data = df.copy()

# Encode the target: "Yes" -> 1, "No" -> 0. Series.map turns every value that
# is not a key of the mapping into NaN, so fail here with a clear message
# instead of letting those NaNs surface later as an obscure split/fit error.
churn_encoded = data["churn"].map({"Yes": 1, "No": 0})
unmapped = churn_encoded.isna()
if unmapped.any():
    bad_labels = data.loc[unmapped, "churn"].unique().tolist()
    raise ValueError(
        f"Unexpected values in 'churn' column (expected 'Yes'/'No'): {bad_labels}"
    )
data["churn"] = churn_encoded

# Coerce non-numeric total_charges entries to NaN, then fill them with the
# column median. NOTE(review): the median is computed over the full dataset,
# i.e. before the train/test split below.
data["total_charges"] = pd.to_numeric(data["total_charges"], errors="coerce")
data["total_charges"] = data["total_charges"].fillna(data["total_charges"].median())

# Two numeric features; binary target.
X = data[["monthly_charges", "total_charges"]]
y = data["churn"]

# Train model: 70/30 split stratified on the target, fixed seed for
# reproducibility, class weights balanced in the classifier.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
model = LogisticRegression(max_iter=1000, class_weight="balanced")
model.fit(X_train, y_train)


In [None]:
# Score every customer with the fitted model. Column 1 of predict_proba is the
# probability of the second class, which is churn = 1 here.
churn_prob = model.predict_proba(X)[:, 1]

results = pd.DataFrame(
    {"Customer": customer_ids, "Churn probability": churn_prob}
)

# Keep only customers at or above the 0.50 probability threshold.
results_filtered = results.loc[results["Churn probability"].ge(0.50)]

# Rank from highest to lowest probability, then round to two decimals.
results_sorted = (
    results_filtered
    .sort_values(by="Churn probability", ascending=False)
    .round({"Churn probability": 2})
)

# Show the ten highest-ranked customers as plain text without the index.
print(results_sorted.head(10).to_string(index=False))


Customer  Churn probability
 C008283               0.62
 C005035               0.61
 C003399               0.61
 C001652               0.61
 C008676               0.61
 C001909               0.61
 C001526               0.61
 C005916               0.61
 C004952               0.61
 C009349               0.61


In [None]:
import ipywidgets as widgets
from IPython.display import display

# Green "Export to CSV" button plus an output area for status messages.
export_button = widgets.Button(description='Export to CSV', button_style='success')
status = widgets.Output()

def export_to_csv(_):
    """Write `results_sorted` to CSV and report the outcome in `status`.

    Registered as the click handler of `export_button`. The single positional
    argument supplied by `on_click` is ignored.

    A failed write (e.g. missing directory or insufficient permissions) is
    reported in the `status` output area instead of being raised from the
    callback, so the user gets feedback next to the button.
    """
    output_path = '../data/churn_predictions.csv'
    with status:
        # Replace any message left over from a previous click.
        status.clear_output()
        try:
            results_sorted.to_csv(output_path, index=False)
        except OSError as exc:
            print(f'Export failed: {exc}')
        else:
            print(f'Exported to {output_path}')

# Wire up the handler and render the button above its status area.
export_button.on_click(export_to_csv)
display(export_button, status)


Button(button_style='success', description='Export to CSV', style=ButtonStyle())

Output()