In [1]:
# Importing the necessary packages
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from sqlalchemy import text
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.utils import median_survival_times
import matplotlib.pyplot as plt
import seaborn as sns
from Database.database import engine, SessionLocal
from Database.models import FactUserAnalyticsSnapshot, DimUser

pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)

print("Imported successfully")

Imported successfully


In [2]:
# Loading RFM + Churn Data for survival analysis
snapshot_date_key = int(datetime.now().strftime("%Y%m%d"))

print("="*80)
print("LOADING DATA FOR SURVIVAL ANALYSIS")
print("="*80)

# Loading data using ORM
print("Loading data from fact_user_analytics_snapshot...")
with SessionLocal() as session:
    # Load user analytics with signup dates
    records = session.query(
        FactUserAnalyticsSnapshot.user_key,
        FactUserAnalyticsSnapshot.subscription_plan_key,
        FactUserAnalyticsSnapshot.rfm_recency,
        FactUserAnalyticsSnapshot.rfm_frequency,
        FactUserAnalyticsSnapshot.rfm_monetary,
        FactUserAnalyticsSnapshot.segment_label,
        FactUserAnalyticsSnapshot.engagement_level,
        FactUserAnalyticsSnapshot.churn_probability,
        FactUserAnalyticsSnapshot.churn_risk_band,
        DimUser.signup_date_key
    ).join(
        DimUser, 
        FactUserAnalyticsSnapshot.user_key == DimUser.user_key
    ).filter(
        FactUserAnalyticsSnapshot.snapshot_date_key == snapshot_date_key,
        FactUserAnalyticsSnapshot.subscription_plan_key.in_([2, 3, 4, 5])
    ).all()
    
    data = [{
        'user_key': r.user_key,
        'subscription_plan_key': r.subscription_plan_key,
        'rfm_recency': r.rfm_recency,
        'rfm_frequency': r.rfm_frequency,
        'rfm_monetary': r.rfm_monetary,
        'segment_label': r.segment_label,
        'engagement_level': r.engagement_level,
        'churn_probability': r.churn_probability,
        'churn_risk_band': r.churn_risk_band,
        'signup_date_key': r.signup_date_key
    } for r in records]

df = pd.DataFrame(data)

print(f"Loaded {len(df):,} premium users")

# Survival analysis variables
def date_key_to_datetime(date_key):
    """Convert date_key (YYYYMMDD) to datetime"""
    date_str = str(date_key)
    return pd.to_datetime(date_str, format='%Y%m%d')

df['signup_date'] = df['signup_date_key'].apply(date_key_to_datetime)
df['snapshot_date'] = pd.to_datetime(str(snapshot_date_key), format='%Y%m%d')
df['duration'] = (df['snapshot_date'] - df['signup_date']).dt.days

def define_churn_event(row):
    """
    Define churn using multiple objective criteria
    """
    # High recency (inactive)
    if row['rfm_recency'] > 60:
        return 1
    
    # High predicted churn probability
    if row['churn_probability'] > 0.7:
        return 1
    
    # Low frequency (never really engaged)
    if row['rfm_frequency'] < 7:
        return 1
    
    # Explicitly churned segment
    if row['segment_label'] in ['Recently Churned', 'Dormant Premium']:
        return 1
    
    # Otherwise, still active
    return 0

df['event'] = df.apply(define_churn_event, axis=1)
df = df[df['duration'] > 0]

print("\n Survival Data Summary:")
print(f"  Total users:          {len(df):,}")
print(f"  Churned (event=1):    {df['event'].sum():,} ({df['event'].sum()/len(df)*100:.1f}%)")
print(f"  Active (censored):    {(df['event']==0).sum():,} ({(df['event']==0).sum()/len(df)*100:.1f}%)")
print(f"  Avg duration:         {df['duration'].mean():.1f} days")
print(f"  Median duration:      {df['duration'].median():.1f} days")


LOADING DATA FOR SURVIVAL ANALYSIS
Loading data from fact_user_analytics_snapshot...
Loaded 802 premium users

 Survival Data Summary:
  Total users:          802
  Churned (event=1):    415 (51.7%)
  Active (censored):    387 (48.3%)
  Avg duration:         542.8 days
  Median duration:      545.5 days


In [3]:
#Fitting Kaplan-Meier model
print("="*80)
print("FITTING KAPLAN-MEIER SURVIVAL MODEL")
print("="*80)

kmf = KaplanMeierFitter()
kmf.fit(durations=df['duration'], event_observed=df['event'], label='All Users')

median_survival = kmf.median_survival_time_
print(f"\n Overall Median Survival Time: {median_survival:.1f} days")

print("\n Survival Probabilities:")
for days in [30, 60, 90, 180, 365]:
    try:
        survival_prob = kmf.predict(days)
        print(f"  {days:3d} days: {survival_prob:.1%} chance of still being active")
    except:
        print(f"  {days:3d} days: N/A (beyond observation period)")

print("\n Survival Function Summary:")
print(kmf.survival_function_.describe())


FITTING KAPLAN-MEIER SURVIVAL MODEL

 Overall Median Survival Time: 789.0 days

 Survival Probabilities:
   30 days: 98.6% chance of still being active
   60 days: 96.9% chance of still being active
   90 days: 95.6% chance of still being active
  180 days: 91.6% chance of still being active
  365 days: 81.5% chance of still being active

 Survival Function Summary:
        All Users
count  581.000000
mean     0.655441
std      0.247013
min      0.000000
25%      0.479947
50%      0.707543
75%      0.862206
max      1.000000


In [4]:
#Comparing survival curves by engagement level
print("="*80)
print("SURVIVAL ANALYSIS BY ENGAGEMENT LEVEL")
print("="*80)

engagement_levels = df['engagement_level'].unique()

print(f"Analyzing {len(engagement_levels)} engagement segments...")

survival_stats = []

for level in engagement_levels:
    segment_df = df[df['engagement_level'] == level]
    
    if len(segment_df) > 0:
        kmf_segment = KaplanMeierFitter()
        kmf_segment.fit(
            durations=segment_df['duration'], 
            event_observed=segment_df['event'],
            label=level
        )
        
        median_time = kmf_segment.median_survival_time_
        
        survival_stats.append({
            'Engagement_Level': level,
            'User_Count': len(segment_df),
            'Churned_Count': segment_df['event'].sum(),
            'Churn_Rate': f"{segment_df['event'].mean():.1%}",
            'Median_Survival_Days': round(median_time, 1) if not np.isnan(median_time) else 'N/A'
        })

survival_stats_df = pd.DataFrame(survival_stats).sort_values('User_Count', ascending=False)

print("\n SURVIVAL STATISTICS BY ENGAGEMENT LEVEL:")
print(survival_stats_df.to_string(index=False))


SURVIVAL ANALYSIS BY ENGAGEMENT LEVEL
Analyzing 4 engagement segments...

 SURVIVAL STATISTICS BY ENGAGEMENT LEVEL:
Engagement_Level  User_Count  Churned_Count Churn_Rate  Median_Survival_Days
  Medium Engaged         254             94      37.0%                 937.0
  Highly Engaged         202              0       0.0%                   inf
         Dormant         184            184     100.0%                 586.0
         At Risk         162            137      84.6%                 606.0


In [5]:
#Fitteing Cox Proportional Hazards model
print("="*80)
print("FITTING COX PROPORTIONAL HAZARDS MODEL")
print("="*80)

cox_features = [
    'rfm_frequency',
    'rfm_monetary',
    'churn_probability'
]

cox_df = df[['duration', 'event'] + cox_features].copy()
cox_df = cox_df.dropna()

print(f"Training Cox model on {len(cox_df):,} users with {len(cox_features)} features...")

cph = CoxPHFitter()
cph.fit(cox_df, duration_col='duration', event_col='event')

print("\nCOX MODEL SUMMARY:")
print(cph.summary)

concordance = cph.concordance_index_
print(f"\nModel Performance:")
print(f"  Concordance Index: {concordance:.4f}")
print(f"  (0.5 = random, 1.0 = perfect)")

# Significant features
print("\nSIGNIFICANT FEATURES (p < 0.05):")
significant = cph.summary[cph.summary['p'] < 0.05]
if len(significant) > 0:
    print(significant[['coef', 'exp(coef)', 'p']])
else:
    print("No significant features at p < 0.05")


FITTING COX PROPORTIONAL HAZARDS MODEL
Training Cox model on 802 users with 3 features...

COX MODEL SUMMARY:
                       coef  exp(coef)  se(coef)  coef lower 95%  coef upper 95%  exp(coef) lower 95%  exp(coef) upper 95%  cmp to         z             p   -log2(p)
covariate                                                                                                                                                            
rfm_frequency     -0.057657   0.943973  0.008099       -0.073530       -0.041784             0.929108             0.959076     0.0 -7.119489  1.083276e-12  39.747737
rfm_monetary      -0.008296   0.991739  0.002462       -0.013122       -0.003469             0.986964             0.996537     0.0 -3.368799  7.549656e-04  10.371301
churn_probability  1.770148   5.871721  0.190582        1.396614        2.143682             4.041492             8.530788     0.0  9.288116  1.570437e-20  65.787396

Model Performance:
  Concordance Index: 0.7270
  (0.5 = ran

In [6]:
#Calculating survival metrics for each user
print("="*80)
print("PREDICTING SURVIVAL METRICS FOR ALL USERS")
print("="*80)

pred_df = df[cox_features].copy()
pred_df = pred_df.fillna(pred_df.median())

survival_functions = cph.predict_survival_function(pred_df)

median_survival_times = []
risk_90d = []

for idx in survival_functions.columns:
    sf = survival_functions[idx]
    
    try:
        median_time = sf[sf <= 0.5].index[0]
    except:
        median_time = sf.index[-1] 
    
    median_survival_times.append(median_time)
    
    try:
        survival_90d = sf.loc[90] if 90 in sf.index else sf.iloc[-1]
        risk_90d.append(1 - survival_90d) 
    except:
        risk_90d.append(0.5) 

df['survival_median_time_to_downgrade'] = median_survival_times
df['survival_risk_90d'] = risk_90d

print(f"Survival metrics calculated for {len(df):,} users")

print("\nSurvival Metrics Summary:")
print(f"  Median Time to Downgrade:")
print(f"    Mean:   {df['survival_median_time_to_downgrade'].mean():.1f} days")
print(f"    Median: {df['survival_median_time_to_downgrade'].median():.1f} days")
print(f"    Min:    {df['survival_median_time_to_downgrade'].min():.1f} days")
print(f"    Max:    {df['survival_median_time_to_downgrade'].max():.1f} days")

print(f"\n  90-Day Churn Risk:")
print(f"    Mean:   {df['survival_risk_90d'].mean():.3f}")
print(f"    Median: {df['survival_risk_90d'].median():.3f}")
print(f"    Min:    {df['survival_risk_90d'].min():.3f}")
print(f"    Max:    {df['survival_risk_90d'].max():.3f}")

# Risk distribution
print("\n90-Day Risk Distribution:")
risk_bands = pd.cut(df['survival_risk_90d'], bins=[0, 0.3, 0.6, 1.0], labels=['Low', 'Medium', 'High'])
print(risk_bands.value_counts().sort_index())


PREDICTING SURVIVAL METRICS FOR ALL USERS
Survival metrics calculated for 802 users

Survival Metrics Summary:
  Median Time to Downgrade:
    Mean:   854.7 days
    Median: 944.0 days
    Min:    287.0 days
    Max:    1096.0 days

  90-Day Churn Risk:
    Mean:   0.818
    Median: 0.901
    Min:    0.329
    Max:    1.000

90-Day Risk Distribution:
survival_risk_90d
Low         0
Medium    164
High      638
Name: count, dtype: int64


In [7]:
#Updating fact_user_analytics_snapshot with survival predictions
print("="*80)
print("UPDATING DATABASE WITH SURVIVAL PREDICTIONS")
print("="*80)

update_df = df[['user_key', 'survival_median_time_to_downgrade', 'survival_risk_90d']].copy()

print(f"Updating {len(update_df):,} user records...")

updated_count = 0
with SessionLocal() as session:
    for idx, row in update_df.iterrows():
        record = session.query(FactUserAnalyticsSnapshot).filter(
            FactUserAnalyticsSnapshot.user_key == int(row['user_key']),
            FactUserAnalyticsSnapshot.snapshot_date_key == snapshot_date_key
        ).first()
        
        if record:
            record.survival_median_time_to_downgrade = int(round(row['survival_median_time_to_downgrade']))
            record.survival_risk_90d = float(row['survival_risk_90d'])
            updated_count += 1
            
            if updated_count % 500 == 0:
                session.commit()
                print(f"  Updated {updated_count:,} records...")
    
    session.commit()

print(f"\nUpdated {updated_count:,} records in fact_user_analytics_snapshot")

with SessionLocal() as session:
    total = session.query(FactUserAnalyticsSnapshot).filter(
        FactUserAnalyticsSnapshot.snapshot_date_key == snapshot_date_key
    ).count()
    
    with_survival = session.query(FactUserAnalyticsSnapshot).filter(
        FactUserAnalyticsSnapshot.snapshot_date_key == snapshot_date_key,
        FactUserAnalyticsSnapshot.survival_median_time_to_downgrade.isnot(None)
    ).count()
    
    with_risk = session.query(FactUserAnalyticsSnapshot).filter(
        FactUserAnalyticsSnapshot.snapshot_date_key == snapshot_date_key,
        FactUserAnalyticsSnapshot.survival_risk_90d.isnot(None)
    ).count()
    
    print("\nVerification:")
    print(f"  Total records:                    {total:,}")
    print(f"  With survival_median_time:        {with_survival:,}")
    print(f"  With survival_risk_90d:           {with_risk:,}")

print("\nDatabase update complete!")


UPDATING DATABASE WITH SURVIVAL PREDICTIONS
Updating 802 user records...
  Updated 500 records...

Updated 802 records in fact_user_analytics_snapshot

Verification:
  Total records:                    1,000
  With survival_median_time:        802
  With survival_risk_90d:           802

Database update complete!


In [8]:
#Generating Insights
print("="*80)
print("ACTIONABLE INSIGHTS FROM SURVIVAL ANALYSIS")
print("="*80)

high_risk_soon = df[
    (df['survival_risk_90d'] > 0.6) & 
    (df['survival_median_time_to_downgrade'] < 90)
].sort_values('survival_risk_90d', ascending=False)

print(f"\nHIGH PRIORITY: {len(high_risk_soon)} users at high risk in next 90 days")
print("Top 10 users needing immediate intervention:")
print(high_risk_soon.head(10)[[
    'user_key', 
    'segment_label', 
    'churn_probability',
    'survival_median_time_to_downgrade',
    'survival_risk_90d'
]].to_string(index=False))

print("\nSURVIVAL METRICS BY RFM SEGMENT:")
segment_survival = df.groupby('segment_label').agg({
    'user_key': 'count',
    'survival_median_time_to_downgrade': 'mean',
    'survival_risk_90d': 'mean'
}).round(1)
segment_survival.columns = ['User_Count', 'Avg_Median_Survival_Days', 'Avg_90d_Risk']
segment_survival = segment_survival.sort_values('Avg_90d_Risk', ascending=False)
print(segment_survival)

ACTIONABLE INSIGHTS FROM SURVIVAL ANALYSIS

HIGH PRIORITY: 0 users at high risk in next 90 days
Top 10 users needing immediate intervention:
Empty DataFrame
Columns: [user_key, segment_label, churn_probability, survival_median_time_to_downgrade, survival_risk_90d]
Index: []

SURVIVAL METRICS BY RFM SEGMENT:
                            User_Count  Avg_Median_Survival_Days  Avg_90d_Risk
segment_label                                                                 
Casual Users                        47                     644.9           1.0
Declining Engagement                93                     702.1           1.0
Dormant Premium                     62                     487.3           1.0
High-Value at Risk                  69                     562.8           1.0
New Premium Users                   50                     785.9           1.0
Recently Churned                   122                     829.1           0.9
Engaged Subscribers                 76                    1