# Cohort Retention Analysis

This notebook demonstrates customer retention analysis using cohort methodology - a critical skill for data analysts and business intelligence professionals.

## What You'll Learn
- How to assign customers to cohorts based on first purchase
- Building retention matrices
- Creating cohort heatmaps with seaborn
- Calculating and visualizing retention curves
- Deriving actionable insights from retention data

## Business Context
Retention is one of the most important metrics for any business. Understanding when and why customers churn helps prioritize product improvements and marketing efforts.

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings

# Suppress every warning category for the rest of the session so library
# deprecation notices do not clutter the rendered notebook.
warnings.filterwarnings('ignore')

# Plot styling: whitegrid theme, husl colour cycle, and default figure/font sizes
# applied in a single rcParams update.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 11,
})

print("Libraries loaded successfully!")

## 1. Data Loading & Preparation

In [None]:
# Read the bundled sample transactions.
# Point `sample_csv_path` at your own data source to analyse real data.
sample_csv_path = '../data/samples/revenue_sample.csv'
df = pd.read_csv(sample_csv_path)

# Report the frame's dimensions and column names, then preview the first rows
# (the bare last expression renders as a rich table).
print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")
df.head()

In [None]:
# Data preparation
# Normalise the input into the three columns the rest of the notebook relies on:
# 'transaction_date' (datetime), 'customer_id' and 'amount' (numeric).

# Ensure date column is datetime.
# An existing 'transaction_date' column always wins; only when it is absent do we
# fall back to the first column whose name contains 'date'. Without this check a
# column such as 'created_date' appearing earlier would overwrite the real one.
date_cols = [col for col in df.columns if 'date' in col.lower()]
if 'transaction_date' in df.columns:
    df['transaction_date'] = pd.to_datetime(df['transaction_date'])
elif date_cols:
    df['transaction_date'] = pd.to_datetime(df[date_cols[0]])
else:
    # Create synthetic dates if none exist
    np.random.seed(42)
    start_date = datetime(2024, 1, 1)
    df['transaction_date'] = [start_date + timedelta(days=np.random.randint(0, 365)) for _ in range(len(df))]

# Ensure customer_id exists
if 'customer_id' not in df.columns:
    # Create synthetic customer IDs
    np.random.seed(42)
    df['customer_id'] = [f'CUST_{np.random.randint(1, 200):04d}' for _ in range(len(df))]

# Ensure amount column exists.
# Same precedence rule as for dates: keep an existing 'amount' column rather than
# replacing it with another column that merely matches amount/revenue/price.
amount_cols = [col for col in df.columns if 'amount' in col.lower() or 'revenue' in col.lower() or 'price' in col.lower()]
if 'amount' in df.columns:
    df['amount'] = pd.to_numeric(df['amount'], errors='coerce')
elif amount_cols:
    # Non-numeric entries become NaN instead of raising
    df['amount'] = pd.to_numeric(df[amount_cols[0]], errors='coerce')
else:
    # Seed here as well: when the date and customer branches above are skipped,
    # no seed has been set yet and the synthetic amounts would change on every run.
    np.random.seed(42)
    df['amount'] = np.random.uniform(50, 500, len(df))

print(f"Date range: {df['transaction_date'].min()} to {df['transaction_date'].max()}")
print(f"Unique customers: {df['customer_id'].nunique()}")
print(f"Total transactions: {len(df)}")

## 2. Cohort Assignment

A cohort is a group of customers who share a common characteristic - in this case, the month of their first purchase.

In [None]:
# Assign cohorts based on first purchase month
def assign_cohort(df):
    """
    Assign each customer to a cohort based on their first transaction.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'customer_id' and a datetime 'transaction_date' column.

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is not modified) with three added columns:
        'cohort_month' (monthly period of the customer's earliest transaction),
        'transaction_month' (monthly period of the row's transaction) and
        'period_number' (whole months between the two). Rows whose
        transaction date is missing get NaN as 'period_number' rather than
        being counted as period-0 activity.

    Notes
    -----
    The function is safe to call repeatedly: cohort columns left over from a
    previous call are discarded and recomputed, so re-running the notebook cell
    does not produce 'cohort_month_x' / 'cohort_month_y' duplicates.
    """
    derived_cols = ['cohort_month', 'transaction_month', 'period_number']
    # Drop output of an earlier call so the merge below cannot create suffixed duplicates
    df = df.drop(columns=derived_cols, errors='ignore')

    # Get first purchase date for each customer
    customer_first_purchase = df.groupby('customer_id')['transaction_date'].min().reset_index()
    customer_first_purchase.columns = ['customer_id', 'first_purchase_date']

    # Create cohort month
    customer_first_purchase['cohort_month'] = customer_first_purchase['first_purchase_date'].dt.to_period('M')

    # Merge back to main dataframe (inner join; returns a new frame)
    df = df.merge(customer_first_purchase[['customer_id', 'cohort_month']], on='customer_id')

    # Calculate transaction month
    df['transaction_month'] = df['transaction_date'].dt.to_period('M')

    # Calculate period number (months since cohort).
    # Subtracting two monthly periods yields an offset whose `.n` is the month count;
    # a missing date yields NaT, which has no `.n`, so it maps to NaN.
    month_offsets = df['transaction_month'] - df['cohort_month']
    df['period_number'] = month_offsets.apply(lambda offset: getattr(offset, 'n', np.nan))

    return df

# Tag every transaction with its customer's cohort and period number
df = assign_cohort(df)

# Count distinct customers per cohort month and print the table without its index
cohort_counts = (
    df.groupby('cohort_month')['customer_id']
    .nunique()
    .reset_index(name='cohort_size')
)
print("Cohort Sizes:")
print(cohort_counts.to_string(index=False))

In [None]:
# Bar chart: number of new customers in each cohort month
# String labels are stored on cohort_counts so the x-axis shows 'YYYY-MM' text.
cohort_counts['cohort_month_str'] = cohort_counts['cohort_month'].astype(str)
month_labels = cohort_counts['cohort_month_str']
new_customers = cohort_counts['cohort_size']

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(month_labels, new_customers, color='steelblue', edgecolor='navy')

# Write each cohort's size just above its bar, centred horizontally
for bar, val in zip(bars, new_customers):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 1
    ax.text(label_x, label_y, str(val), ha='center', va='bottom', fontsize=10)

ax.set_title('Customer Cohort Sizes by Month', fontsize=14, fontweight='bold')
ax.set_xlabel('Cohort Month', fontsize=12)
ax.set_ylabel('Number of New Customers', fontsize=12)
plt.xticks(rotation=45)

# Fit the layout, write the PNG, then render inline
fig.tight_layout()
fig.savefig('../docs/visualizations/cohort_sizes.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Building the Retention Matrix

In [None]:
def build_retention_matrix(df, max_periods=12):
    """
    Compute, for every cohort and period, the percentage of the cohort that was active.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs 'cohort_month', 'period_number' and 'customer_id' columns.
    max_periods : int, default 12
        Highest period number kept as a column of the matrix.

    Returns
    -------
    tuple
        (retention_matrix, cohort_sizes). The matrix has one row per cohort and
        one column per period, holding distinct customers in that period
        divided by the cohort's period-0 customer count, times 100. cohort_sizes
        has the columns 'cohort_month' and 'cohort_size'.
    """
    # Distinct active customers for each (cohort, period) pair
    activity = (
        df.groupby(['cohort_month', 'period_number'])['customer_id']
        .nunique()
        .reset_index(name='customers')
    )

    # A cohort's size is its customer count in period 0
    is_first_period = activity['period_number'] == 0
    cohort_sizes = activity.loc[is_first_period, ['cohort_month', 'customers']].rename(
        columns={'customers': 'cohort_size'}
    )

    # Attach the size to every row and express activity as a percentage of it
    activity = activity.merge(cohort_sizes, on='cohort_month')
    activity['retention_rate'] = activity['customers'].div(activity['cohort_size']).mul(100)

    # Reshape to cohorts x periods
    retention_matrix = activity.pivot_table(
        values='retention_rate',
        index='cohort_month',
        columns='period_number',
    )

    # Keep only periods up to max_periods
    kept_periods = [period for period in retention_matrix.columns if period <= max_periods]
    return retention_matrix[kept_periods], cohort_sizes

# Build the matrix from the cohort-tagged transactions (default 12-period cap)
# and print it rounded to one decimal place.
retention_matrix, cohort_sizes = build_retention_matrix(df)
print("Retention Matrix (%)")
rounded_matrix_text = retention_matrix.round(1).to_string()
print(rounded_matrix_text)

## 4. Cohort Retention Heatmap

The heatmap is a powerful visualization for understanding retention patterns at a glance.

In [None]:
# Retention heatmap: one row per cohort, one column per period

# Work on a relabelled copy so the row labels render as plain strings
retention_display = retention_matrix.set_axis(retention_matrix.index.astype(str), axis=0)

# Colour scale is pinned to 0-100 so the shading is comparable between runs
heatmap_options = dict(
    annot=True,
    fmt='.1f',
    cmap='YlGnBu',
    vmin=0,
    vmax=100,
    linewidths=0.5,
    cbar_kws={'label': 'Retention Rate (%)'},
)

fig, ax = plt.subplots(figsize=(14, 8))
sns.heatmap(retention_display, ax=ax, **heatmap_options)

ax.set_title('Customer Retention by Cohort\n(Percentage of Customers Still Active)', fontsize=14, fontweight='bold')
ax.set_xlabel('Period (Months Since First Purchase)', fontsize=12)
ax.set_ylabel('Cohort Month', fontsize=12)

fig.tight_layout()
fig.savefig('../docs/visualizations/cohort_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Retention Curves

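Retention curves show how retention changes over time for each cohort. Retention here is the share of a cohort that made at least one purchase in a given month, so a curve can rise as well as fall.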

In [None]:
# One retention line per cohort, coloured along the viridis scale
fig, ax = plt.subplots(figsize=(12, 7))

# One colour per matrix row, sampled from the first 90% of the colormap
colors = plt.cm.viridis(np.linspace(0, 0.9, len(retention_matrix)))

line_style = dict(marker='o', markersize=4, linewidth=2, alpha=0.8)
for i, (cohort, row) in enumerate(retention_matrix.iterrows()):
    # Skip periods with no value for this cohort
    values = row.dropna()
    ax.plot(values.index, values.values, label=str(cohort), color=colors[i], **line_style)

ax.set_title('Retention Curves by Cohort', fontsize=14, fontweight='bold')
ax.set_xlabel('Period (Months Since First Purchase)', fontsize=12)
ax.set_ylabel('Retention Rate (%)', fontsize=12)
ax.set_ylim(0, 105)
ax.grid(True, alpha=0.3)
# Legend sits outside the axes, to the right
ax.legend(title='Cohort', bbox_to_anchor=(1.02, 1), loc='upper left')

fig.tight_layout()
fig.savefig('../docs/visualizations/retention_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Mean retention per period across cohorts, with a +/- 1 standard deviation band
avg_retention = retention_matrix.mean(axis=0)
std_retention = retention_matrix.std(axis=0)

periods = avg_retention.index
band_lower = avg_retention.values - std_retention.values
band_upper = avg_retention.values + std_retention.values

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(periods, avg_retention.values, marker='o', markersize=8,
        linewidth=3, color='#2E86AB', label='Average Retention')

# Shaded region spans one standard deviation either side of the mean
ax.fill_between(periods, band_lower, band_upper, alpha=0.2, color='#2E86AB')

# Label the mean at months 1, 3 and 6 when those periods exist
for period in (1, 3, 6):
    if period not in avg_retention.index:
        continue
    val = avg_retention[period]
    ax.annotate(f'{val:.1f}%', (period, val), textcoords="offset points",
                xytext=(0, 10), ha='center', fontsize=10, fontweight='bold')

ax.set_title('Average Retention Curve with Variance Band', fontsize=14, fontweight='bold')
ax.set_xlabel('Period (Months Since First Purchase)', fontsize=12)
ax.set_ylabel('Average Retention Rate (%)', fontsize=12)
ax.set_ylim(0, 105)
ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('../docs/visualizations/avg_retention_curve.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Churn Analysis

In [None]:
# Calculate period-over-period churn
def calculate_churn(retention_matrix):
    """
    Return the drop in retention from each period column to the next.

    Each cell is the previous column's value minus the current one, so a
    positive number means retention fell. The first column has no predecessor
    and is therefore all NaN.
    """
    period_change = retention_matrix.diff(axis=1)
    # Flip the sign so a fall in retention is reported as positive churn
    return period_change.mul(-1)

churn_matrix = calculate_churn(retention_matrix)

# Average churn by period
avg_churn = churn_matrix.mean(axis=0)

# The first column of the churn matrix has no preceding period and is all NaN,
# so every chart and statistic below starts from the second column.
period_churn = avg_churn.iloc[1:]
has_churn_data = period_churn.notna().any()
# pandas' mean skips NaN periods; a NumPy mean over .values would return NaN
# as soon as a single period had no data.
overall_avg_churn = period_churn.mean()

fig, ax = plt.subplots(figsize=(10, 5))

bars = ax.bar(period_churn.index, period_churn.values, color='coral', edgecolor='darkred')

ax.set_xlabel('Period Transition', fontsize=12)
ax.set_ylabel('Average Churn Rate (%)', fontsize=12)
ax.set_title('Average Churn Rate by Period\n(Percentage Points Lost)', fontsize=14, fontweight='bold')
if has_churn_data:
    # Only draw the reference line (and its legend) when there is something to average
    ax.axhline(y=overall_avg_churn, color='red', linestyle='--',
               label=f'Overall Avg: {overall_avg_churn:.1f}%')
    ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

if has_churn_data:
    # Report the period that actually has the highest average churn instead of assuming period 1
    peak_period = period_churn.idxmax()
    print(f"\nKey Insight: Highest churn occurs in Period {peak_period} ({period_churn.loc[peak_period]:.1f}%)")
    if peak_period == 1:
        print("This is typical - the first month after acquisition is critical for retention.")
else:
    print("\nKey Insight: Not enough periods of data to measure churn.")

## 7. Cohort Revenue Analysis

In [None]:
# Total revenue for each (cohort, period) pair
cohort_revenue = (
    df.groupby(['cohort_month', 'period_number'])['amount']
    .sum()
    .reset_index(name='revenue')
)

# Reshape to cohorts x periods
revenue_matrix = cohort_revenue.pivot_table(
    values='revenue',
    index='cohort_month',
    columns='period_number',
)

# Running total across periods within each cohort
cumulative_revenue = revenue_matrix.cumsum(axis=1)


def _draw_revenue_heatmap(matrix, ax, cmap, cbar_label, title):
    """Draw one annotated heatmap of whole-number values with shared axis labels."""
    sns.heatmap(matrix.round(0), annot=True, fmt='.0f', cmap=cmap,
                ax=ax, cbar_kws={'label': cbar_label})
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('Period')
    ax.set_ylabel('Cohort')


fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ax1, ax2 = axes[0], axes[1]

# Left panel: revenue earned in each period (row labels converted to strings)
revenue_display = revenue_matrix.set_axis(revenue_matrix.index.astype(str), axis=0)
_draw_revenue_heatmap(revenue_display, ax1, 'Greens', 'Revenue ($)',
                      'Revenue by Cohort and Period')

# Right panel: cumulative revenue per cohort
cumulative_display = cumulative_revenue.set_axis(cumulative_revenue.index.astype(str), axis=0)
_draw_revenue_heatmap(cumulative_display, ax2, 'Blues', 'Cumulative Revenue ($)',
                      'Cumulative Revenue by Cohort')

fig.tight_layout()
plt.show()

## 8. Key Insights & Recommendations

In [None]:
# Summary statistics
print("="*60)
print("COHORT RETENTION ANALYSIS - KEY FINDINGS")
print("="*60)


def _format_pct(value):
    """Format a percentage to one decimal place, or 'N/A' when the value is missing."""
    return f"{value:.1f}%" if pd.notna(value) else "N/A"


# 1. Overall retention metrics
# Look each period up by its label: positional access would return the wrong
# period whenever a period column is absent from the matrix. A missing period
# is reported as N/A rather than as a misleading 0.0%.
month_1_retention = avg_retention.get(1, np.nan)
month_3_retention = avg_retention.get(3, np.nan)
month_6_retention = avg_retention.get(6, np.nan)

print(f"\n1. RETENTION BENCHMARKS")
print(f"   - Month 1 Retention: {_format_pct(month_1_retention)}")
print(f"   - Month 3 Retention: {_format_pct(month_3_retention)}")
print(f"   - Month 6 Retention: {_format_pct(month_6_retention)}")

# 2. Best and worst cohorts
# Reset explicitly so a value left in the kernel by an earlier run cannot leak
# into the recommendations below.
best_cohort = None
worst_cohort = None
# idxmax/idxmin need at least one non-missing month-1 value
if 1 in retention_matrix.columns and retention_matrix[1].notna().any():
    best_cohort = retention_matrix[1].idxmax()
    worst_cohort = retention_matrix[1].idxmin()
    print(f"\n2. COHORT PERFORMANCE")
    print(f"   - Best performing cohort: {best_cohort} ({retention_matrix.loc[best_cohort, 1]:.1f}% M1 retention)")
    print(f"   - Worst performing cohort: {worst_cohort} ({retention_matrix.loc[worst_cohort, 1]:.1f}% M1 retention)")

# 3. Churn patterns
print(f"\n3. CHURN PATTERNS")
# Skip the first entry: it has no preceding period to compare against
period_churn = avg_churn.iloc[1:]
if period_churn.notna().any():
    # Derive the worst transition from the data instead of assuming month 0 -> 1
    peak_period = period_churn.idxmax()
    previous_period = avg_churn.index[avg_churn.index.get_loc(peak_period) - 1]
    print(f"   - Highest churn period: Month {previous_period} to Month {peak_period}")
    print(f"   - Average period-over-period churn: {period_churn.mean():.1f}%")
else:
    print(f"   - Not enough periods of data to measure churn")

# 4. Recommendations
top_cohort_label = best_cohort if best_cohort is not None else 'top performers'
print(f"\n4. RECOMMENDATIONS")
print(f"   - Focus onboarding efforts in first 30 days (critical drop-off period)")
print(f"   - Investigate what made cohort {top_cohort_label} successful")
print(f"   - Implement re-engagement campaign for customers approaching 60-day inactivity")
print(f"   - Set up automated check-ins at Day 7, 14, and 30 to improve early retention")

print("\n" + "="*60)

In [None]:
# Collect the main result tables in one dict for downstream reporting
summary_data = dict(
    retention_matrix=retention_matrix,
    avg_retention=avg_retention,
    cohort_sizes=cohort_sizes,
)

# Write the retention matrix next to the sample data and confirm on stdout
export_path = '../data/samples/cohort_retention_matrix.csv'
retention_matrix.to_csv(export_path)
print("Retention matrix exported to data/samples/cohort_retention_matrix.csv")