In [None]:
import os
from pathlib import Path
import warnings
import pickle

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report, roc_curve, auc

# Suppress all warnings (including library deprecation notices) for quieter output.
warnings.filterwarnings("ignore")

# === CONFIG ===
# Input dataset; change INPUT_CSV to read a different file
# (a relative name is resolved against the current working directory).
INPUT_CSV = "DataWave_Music_Sprint_Dataset.csv"
# Destination for every artifact this notebook writes (CSV, plots, model, coefficients).
OUTPUT_DIR = Path(".")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Fixed reference date for tenure calculations, so reruns give identical results.
TODAY = pd.Timestamp(year=2025, month=11, day=25)

# === HELPERS ===
def save_plot(fig, fname):
    """Tighten *fig*'s layout and write it to ``OUTPUT_DIR / fname`` at 150 dpi.

    The destination path is printed once the file has been written.
    """
    target = OUTPUT_DIR / fname
    fig.tight_layout()
    fig.savefig(target, dpi=150)
    print("Saved plot:", target)

# === LOAD (direct pandas read) ===
print("Loading data from:", INPUT_CSV)
try:
    df = pd.read_csv(INPUT_CSV)
except FileNotFoundError as e:
    # A relative INPUT_CSV is resolved against the current working directory,
    # not the script's folder. Report the configured name and that directory
    # so the hint stays correct when INPUT_CSV is changed in CONFIG.
    print(f"Error reading file: {e}")
    print(f"Please ensure '{INPUT_CSV}' exists in {Path.cwd()}, "
          "or set INPUT_CSV in the CONFIG section to the file's full path.")
    raise

print("Initial shape:", df.shape)


# === STANDARDIZE COLUMN NAMES & COMMON RENAMES ===
# Normalise headers: trim, lower-case, and turn single spaces into underscores.
normalized_columns = df.columns.str.strip().str.lower().str.replace(" ", "_")
df.columns = normalized_columns

# Map known header variants onto the canonical names used by the rest of the notebook.
# Identity entries are harmless no-ops kept for documentation.
col_map = {
    # weekly listening
    'avg_listening_hours_per_week': 'avg_listening_hours_per_week',
    'hours_listened': 'avg_listening_hours_per_week',
    # sign-up date
    'join_date': 'join_date',
    'datejoined': 'join_date',
    'date_joined': 'join_date',
    # activity
    'last_active_date': 'last_active',
    # target label
    'churn': 'churned',
    'churned?': 'churned',
}
df = df.rename(columns=col_map)

# Guarantee the categorical columns exist; absent ones are filled with 'unknown'.
absent_categoricals = [name for name in ('subscription_type', 'gender', 'country')
                       if name not in df.columns]
for name in absent_categoricals:
    df[name] = 'unknown'

# === STRING CLEANING: subscription_type, gender, country ===
# .astype(str) renders missing cells as the literal string 'nan', so a later
# .fillna() can never fire. Missing and blank subscription types are therefore
# mapped to 'unknown' explicitly, the same way gender and country are handled.
df['subscription_type'] = df['subscription_type'].astype(str).str.strip().str.lower()
subs_fix = {
    'premum': 'premium', 'premiun': 'premium', 'prem': 'premium',
    'premu m': 'premium', 'premium ': 'premium',
    'fam': 'family', 'family ': 'family',
    'studnt': 'student', 'student ': 'student',
    # placeholders for missing/blank cells after astype(str) + strip()
    'nan': 'unknown', '': 'unknown',
}
# Keys with trailing spaces are redundant after .strip() but harmless.
df['subscription_type'] = df['subscription_type'].replace(subs_fix)

# Gender: expand one-letter codes; missing/blank -> 'unknown'.
df['gender'] = df['gender'].astype(str).str.strip().str.lower().replace({'f':'female','m':'male','nan':'unknown','': 'unknown'}).fillna('unknown')
# Country: missing/blank -> 'Unknown', then title-case every value.
df['country'] = df['country'].astype(str).str.strip().replace({'': 'Unknown', 'nan':'Unknown'}).fillna('Unknown').str.title()

# === NUMERIC CONVERSIONS ===
# Columns coerced to numbers (only those present). Unparseable entries become NaN
# and are imputed further down.
num_map = {
    'age': 'age',
    'avg_listening_hours_per_week': 'avg_listening_hours_per_week',
    'total_songs_played': 'total_songs_played',
    'satisfaction_score': 'satisfaction_score',
    'monthly_fee': 'monthly_fee'
}
for col in num_map:
    if col in df.columns:
        # Remove currency markers ('USD' in any case, '$') and thousands separators
        # before parsing. Otherwise values such as "$9.99" or "9.99 usd" become NaN,
        # and a NaN monthly_fee is later replaced by 0.0.
        df[col] = pd.to_numeric(
            df[col].astype(str).str.replace(r'usd|\$|,', '', case=False, regex=True).str.strip(),
            errors='coerce'
        )

# === SKIP RATE PARSING ===
def parse_skip(x):
    """Convert one raw skip-rate value to a percentage.

    - missing (NaN/None) -> NaN
    - "37%" / "37 %" -> 37.0 (the number before '%' is used as-is)
    - a bare number in [0, 1] is treated as a fraction and multiplied by 100
    - any other bare number is treated as already being a percentage
    - anything unparseable -> NaN
    """
    if pd.isna(x):
        return np.nan
    s = str(x).strip()
    if s.endswith('%'):
        try:
            return float(s[:-1].strip())
        except ValueError:
            return np.nan
    try:
        v = float(s)
    except ValueError:
        return np.nan
    # NOTE(review): a bare 1 is ambiguous (100% as a fraction vs 1%); it is
    # treated as a fraction here.
    return v * 100.0 if 0 <= v <= 1 else v


if 'skip_rate' in df.columns:
    df['skip_rate_pct'] = df['skip_rate'].apply(parse_skip)
else:
    df['skip_rate_pct'] = np.nan

# Force a numeric dtype; an all-missing column is median/0-filled later.
df['skip_rate_pct'] = pd.to_numeric(df['skip_rate_pct'], errors='coerce')

# === CHURN CONVERSION ===
# Turn the target into an int column. Tokens yes/no, true/false, 1/0 are matched
# after lower-casing and trimming. Other strings are parsed as numbers, and
# anything unparseable (including missing labels) becomes 0.
# NOTE(review): counting missing labels as non-churners can bias churn rates.
if 'churned' not in df.columns:
    df['churned'] = 0
else:
    churn_tokens = (
        df['churned'].astype(str).str.lower().str.strip()
        .replace({'yes': 1, 'no': 0, 'true': 1, 'false': 0, '1': 1, '0': 0})
    )
    df['churned'] = pd.to_numeric(churn_tokens, errors='coerce').fillna(0).astype(int)

# === DATES: join_date and tenure ===
if 'join_date' in df.columns:
    # First pass: month-first parsing (dayfirst=False). Only values still unparsed
    # but not missing get a second, day-first attempt.
    df['join_date_parsed'] = pd.to_datetime(df['join_date'], errors='coerce', dayfirst=False)
    mask_na = df['join_date_parsed'].isna() & df['join_date'].notna()
    if mask_na.any():
        try:
            df.loc[mask_na, 'join_date_parsed'] = pd.to_datetime(
                df.loc[mask_na, 'join_date'], errors='coerce', dayfirst=True
            )
        except Exception as exc:
            # Best effort: keep the month-first result, but report the failure
            # instead of silently discarding it.
            print(f"Warning: day-first re-parse of join_date failed ({exc!r}); "
                  f"{int(mask_na.sum())} value(s) left unparsed.")
    df['join_date'] = df['join_date_parsed']
    df.drop(columns=['join_date_parsed'], inplace=True)
else:
    df['join_date'] = pd.NaT

# Whole days from join_date to the fixed TODAY; NaN where join_date is missing.
# NOTE(review): join dates after TODAY give negative tenure; nothing clips them.
df['tenure_days'] = (TODAY - df['join_date']).dt.days

# === DUPLICATES: user_id ===
# When a user_id column exists, report repeats and keep only the first row per id.
if 'user_id' in df.columns:
    is_repeat = df['user_id'].duplicated()
    dup_count = is_repeat.sum()
    print("Duplicate user_id count:", dup_count)
    if dup_count:
        df = df[~is_repeat]

# === HANDLE MISSING VALUES ===
# Categorical columns were built via astype(str) above, so these fills are a safety net.
df['subscription_type'] = df['subscription_type'].fillna('unknown')
df['gender'] = df['gender'].fillna('unknown')
df['country'] = df['country'].fillna('Unknown')

# A missing monthly fee is set to 0.0 instead of the median (presumably a
# free-tier assumption; confirm). Guarded like every other numeric column,
# since the dataset may not contain a fee column at all.
if 'monthly_fee' in df.columns:
    df['monthly_fee'] = df['monthly_fee'].fillna(0.0)

# Fill remaining numeric gaps with the column median (0.0 if the column is entirely NaN).
# NOTE(review): medians come from the full dataset before the train/test split,
# so some test-set information leaks into training; an imputer inside the model
# pipeline would avoid this.
numeric_fill_cols = [
    'age','avg_listening_hours_per_week','total_songs_played',
    'satisfaction_score','monthly_fee','skip_rate_pct','tenure_days'
]
for c in numeric_fill_cols:
    if c in df.columns:
        med = df[c].median()
        if pd.isna(med):
            med = 0.0
        df[c] = df[c].fillna(med)

# === REMOVE UNREALISTIC OUTLIERS ===
# Keep rows whose value lies inside an inclusive [low, high] range; rows with a NaN
# in a checked column are dropped too. Columns that don't exist are skipped.
_plausible_ranges = {
    'age': (10, 90),
    'avg_listening_hours_per_week': (0, 200),
}
for _col, (_low, _high) in _plausible_ranges.items():
    if _col in df.columns:
        df = df[df[_col].between(_low, _high)]

print("Shape after cleaning:", df.shape)

# === FINAL SUBSCRIPTION NORMALIZATION ===
# Second pass of subscription-type typo fixes, then lower-case every value.
late_subs_fix = {'fam': 'family', 'family ': 'family', 'studnt': 'student', 'student ': 'student'}
normalized_subs = df['subscription_type'].replace(late_subs_fix)
df['subscription_type'] = normalized_subs.str.lower()

# === SAVE CLEANED CSV ===
# Write the cleaned frame without the index.
clean_path = OUTPUT_DIR / "datawave_music_cleaned_full.csv"
df.to_csv(clean_path, index=False)
print("Saved cleaned CSV to:", clean_path)



# === EXPLORATORY ANALYSIS & PLOTS ===
overall_churn = df['churned'].mean()
print(f"Overall churn rate: {overall_churn:.4f}")

# Mean churn per subscription type, highest first.
churn_by_sub = (
    df.groupby('subscription_type')['churned']
    .mean()
    .sort_values(ascending=False)
)
print("\nChurn rate by subscription_type:\n", churn_by_sub)

# Bar chart of those per-subscription churn rates.
fig = plt.figure(figsize=(8, 4))
ax = churn_by_sub.plot(kind='bar', rot=30)
ax.set(title="Churn rate by subscription_type", ylabel="Churn rate")
save_plot(fig, "plot_churn_by_subscription.png")
plt.close(fig)

# Histogram of weekly listening hours.
if 'avg_listening_hours_per_week' in df.columns:
    fig = plt.figure(figsize=(7, 4))
    ax = df['avg_listening_hours_per_week'].plot(kind='hist', bins=30)
    ax.set_title("Distribution of Avg Listening Hours per Week")
    ax.set_xlabel("Avg listening hours per week")
    save_plot(fig, "plot_listening_dist.png")
    plt.close(fig)

# Box plot of satisfaction score split by churn label.
if 'satisfaction_score' in df.columns:
    fig = plt.figure(figsize=(6, 4))
    ax = fig.gca()
    df[['satisfaction_score', 'churned']].boxplot(column='satisfaction_score', by='churned', ax=ax)
    ax.set_title("Satisfaction score by churn status")
    fig.suptitle("")  # clear the figure-level title, keep only the axes title
    ax.set_xlabel("Churned (0=No, 1=Yes)")
    ax.set_ylabel("Satisfaction score")
    save_plot(fig, "plot_satisfaction_by_churn.png")
    plt.close(fig)

# Bar charts of a metric's mean per churn label, in this order.
for metric, title, ylabel, out_name in (
    ('avg_listening_hours_per_week', "Avg listening hours by churn status",
     "Avg listening hours per week", "plot_listening_by_churn.png"),
    ('skip_rate_pct', "Skip rate (%) by churn status",
     "Skip rate (%)", "plot_skiprate_by_churn.png"),
):
    if metric not in df.columns:
        continue
    fig = plt.figure(figsize=(6, 4))
    ax = df.groupby('churned')[metric].mean().plot(kind='bar', rot=0)
    ax.set_title(title)
    ax.set_xlabel("Churned (0=No, 1=Yes)")
    ax.set_ylabel(ylabel)
    save_plot(fig, out_name)
    plt.close(fig)

# Histogram of user ages.
if 'age' in df.columns:
    fig = plt.figure(figsize=(7, 4))
    ax = df['age'].plot(kind='hist', bins=20)
    ax.set_title("Age distribution")
    ax.set_xlabel("Age")
    save_plot(fig, "plot_age_dist.png")
    plt.close(fig)

# === SIMPLE CHURN MODEL (Logistic Regression) ===
print("\nTraining churn model (Logistic Regression) with class_weight='balanced'...")

# Numeric candidates, kept in this order and only if they survived cleaning.
candidate_numeric = ['age', 'avg_listening_hours_per_week', 'total_songs_played',
                     'skip_rate_pct', 'satisfaction_score', 'monthly_fee', 'tenure_days']
feature_cols = [name for name in candidate_numeric if name in df.columns]

if len(feature_cols) == 0:
    raise RuntimeError("No numeric features available for modeling.")

cat_cols = ['subscription_type', 'gender']
num_cols = feature_cols

X = df[feature_cols + cat_cols].copy()
y = df['churned'].copy()

# Standardise numeric columns and one-hot encode categoricals; other columns are dropped.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), num_cols),
        ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), cat_cols),
    ],
    remainder='drop',
)

# class_weight='balanced' reweights samples inversely to class frequency.
logreg = LogisticRegression(max_iter=500, solver='liblinear', class_weight='balanced')
clf = Pipeline(steps=[('pre', preprocessor), ('clf', logreg)])

# Train/test split (75/25). Stratify on the label when possible; if stratification
# raises ValueError, fall back to a plain random split.
split_kwargs = dict(test_size=0.25, random_state=42)
try:
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, **split_kwargs)
except ValueError:
    X_train, X_test, y_train, y_test = train_test_split(X, y, **split_kwargs)

clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
if hasattr(clf.named_steps['clf'], "predict_proba"):
    y_proba = clf.predict_proba(X_test)[:, 1]
else:
    y_proba = np.zeros(len(y_test))

acc = accuracy_score(y_test, y_pred)
# ROC AUC is undefined when the test labels contain a single class.
if len(np.unique(y_test)) > 1:
    roc = roc_auc_score(y_test, y_proba)
else:
    roc = float('nan')
print("Accuracy:", acc)
print("ROC AUC:", roc)
print("\nClassification report:\n", classification_report(y_test, y_pred, zero_division=0))

# Persist the fitted pipeline (only unpickle this file from a trusted location).
model_path = OUTPUT_DIR / "churn_logreg_model_final_balanced.pkl"
with model_path.open("wb") as model_file:
    pickle.dump(clf, model_file)
print("Saved model to:", model_path)

# ROC curve, drawn only when the test set contains both classes.
if len(np.unique(y_test)) > 1:
    fpr, tpr, _ = roc_curve(y_test, y_proba)
    roc_auc = auc(fpr, tpr)
    fig = plt.figure(figsize=(6, 5))
    ax = fig.gca()
    ax.plot(fpr, tpr, label=f"AUC={roc_auc:.3f}")
    ax.plot([0, 1], [0, 1], linestyle='--')  # chance diagonal
    ax.set_title("ROC curve")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend()
    save_plot(fig, "plot_roc_curve_balanced.png")
    plt.close(fig)

# Save coefficients: numeric feature names first, then one-hot column names,
# mirroring the order of the ColumnTransformer's transformers.
ohe = clf.named_steps['pre'].named_transformers_['cat']
try:
    ohe_names = list(ohe.get_feature_names_out(cat_cols))
except Exception:
    ohe_names = []
transformed_cols = [*num_cols, *ohe_names]

coefs = clf.named_steps['clf'].coef_[0]
coef_df = pd.DataFrame({
    'feature': transformed_cols,
    'coef': coefs[:len(transformed_cols)],
})
coef_out = OUTPUT_DIR / "model_coefficients_balanced.csv"
coef_df.to_csv(coef_out, index=False)
print("Saved model coefficients to:", coef_out)

# Final summary of where the key artifacts were written.
print("\nPipeline complete. Key outputs in:", OUTPUT_DIR)
print("Cleaned CSV:", clean_path)
print("Model file:", model_path)

Duplicate user_id count: 0
Shape after cleaning: (691, 14)
Saved cleaned CSV to: datawave_music_cleaned_full.csv
Overall churn rate: 0.3054

Churn rate by subscription_type:
 subscription_type
family     0.323671
premium    0.315217
student    0.306533
free       0.247525
Name: churned, dtype: float64
Saved plot: plot_churn_by_subscription.png
Saved plot: plot_listening_dist.png
Saved plot: plot_satisfaction_by_churn.png
Saved plot: plot_listening_by_churn.png
Saved plot: plot_skiprate_by_churn.png
Saved plot: plot_age_dist.png

Training churn model (Logistic Regression) with class_weight='balanced'...
Accuracy: 0.4797687861271676
ROC AUC: 0.4979559748427673

Classification report:
               precision    recall  f1-score   support

           0       0.67      0.49      0.57       120
           1       0.28      0.45      0.35        53

    accuracy                           0.48       173
   macro avg       0.48      0.47      0.46       173
weighted avg       0.55      0.48   