# MMM Modeling with Mediation Assumption

This notebook implements the core mediation-aware marketing mix model (MMM), in which Google spend mediates the relationship between social channels and revenue.

## Key Components:
1. **Two-Stage Modeling**: 
   - Stage 1: Social channels → Google spend
   - Stage 2: Google spend + direct variables → Revenue
2. **Causal Framework**: Each social channel's effect on revenue is decomposed into a direct effect, an indirect effect (via Google spend), and their total
3. **Time Series Validation**: Proper cross-validation respecting temporal order
4. **Feature Engineering**: Adstock, saturation, and interaction effects


In [None]:
# Import necessary libraries
import sys
import os
sys.path.append('../src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error

# Import our custom modules
from data_preparation import DataPreparator
from mediation_model import MediationMMM
from utils import set_random_seed

# Reproducibility: pin the project-wide random seed before anything stochastic runs
set_random_seed(42)

# Plot defaults applied to every figure produced in this notebook
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (12, 8)})


In [None]:
# Load and prepare data
print("Loading and preparing data...")

# Load the raw input data (prepare_data below applies the transformations)
data = pd.read_csv('../data/raw/mmm_data.csv')
data['date'] = pd.to_datetime(data['date'])

# The lag / moving-average columns selected later and the time-series
# cross-validation both depend on rows being in chronological order, so enforce
# it here instead of trusting the CSV's row order. A stable sort keeps ties in
# file order; this is a no-op when the file is already sorted.
data = data.sort_values('date', kind='stable').reset_index(drop=True)

# Prepare data with transformations
prep = DataPreparator(random_seed=42)
prepared_data = prep.prepare_data(data, apply_transformations=True)

print(f"Data prepared: {prepared_data.shape}")
print(f"Features available: {len(prepared_data.columns)}")
print(f"Original features: {len(data.columns)}")
# Net column delta; understates "new" features if prepare_data also drops columns
print(f"New features created: {len(prepared_data.columns) - len(data.columns)}")


## 1. Define Feature Sets for Mediation Model

We need to define which features go into each stage of our mediation model:


In [None]:
# Define feature sets for mediation model
# Stage 1: Social channels → Google spend
social_features = [col for col in prepared_data.columns 
                  if any(social in col.lower() for social in ['facebook', 'tiktok', 'instagram', 'snapchat']) 
                  and 'spend' in col.lower()]

# Stage 2: Google spend + direct variables → Revenue
direct_features = [col for col in prepared_data.columns 
                  if any(direct in col.lower() for direct in ['email', 'sms', 'price', 'promo', 'follower'])
                  or col == 'google_spend']

# Add interaction and lag features
interaction_features = [col for col in prepared_data.columns if '_x_' in col]
lag_features = [col for col in prepared_data.columns if 'lag_' in col]
ma_features = [col for col in prepared_data.columns if '_ma_' in col]

# Combine all features for stage 2. The substring filters above overlap: a column
# such as one containing both 'email' and 'lag_' lands in direct_features AND
# lag_features. Drop repeats, keeping first-seen order, so no column is fed to
# the model twice.
# NOTE(review): if prepared_data contains lag/moving-average columns derived from
# 'revenue' itself, confirm their windows exclude the current period to avoid
# target leakage.
stage2_features = list(dict.fromkeys(
    direct_features + interaction_features + lag_features + ma_features
))

print("Stage 1 Features (Social → Google):")
for i, feature in enumerate(social_features, 1):
    print(f"  {i}. {feature}")

print(f"\nStage 2 Features (Google + Direct → Revenue): {len(stage2_features)} features")
print("Direct features:")
for feature in direct_features:
    print(f"  - {feature}")
print("Interaction features:")
for feature in interaction_features:
    print(f"  - {feature}")
print("Lag features:")
for feature in lag_features:
    print(f"  - {feature}")
print("Moving average features:")
for feature in ma_features:
    print(f"  - {feature}")


## 2. Fit Mediation Model

Now let's fit our two-stage mediation model:


In [None]:
# Build and train the two-stage mediation-aware MMM
print("Fitting mediation-aware MMM model...")

model = MediationMMM(random_seed=42)

# Stage 1: social channels → Google spend; stage 2: Google spend plus the
# direct features → revenue.
results = model.fit(
    data=prepared_data,
    target_col='revenue',
    google_col='google_spend',
    social_cols=social_features,
    direct_cols=direct_features,
)

print("✅ Model fitting complete!")
stage1_fit_metrics = results['stage1_metrics']
stage2_fit_metrics = results['stage2_metrics']
print(f"Stage 1 R²: {stage1_fit_metrics['r2']:.4f}")
print(f"Stage 2 R²: {stage2_fit_metrics['r2']:.4f}")
# "Overall" here is the stage-2 (revenue) MAPE
print(f"Overall MAPE: {stage2_fit_metrics['mape']:.2f}%")


## 3. Mediation Effects Analysis

Let's examine the mediation effects — how social channels influence revenue through Google spend:


In [None]:
# Report the per-channel direct / indirect / total effects produced by the fit
mediation_effects = results['mediation_effects']

for header_line in (
    "MEDIATION EFFECTS ANALYSIS",
    "=" * 50,
    "How social channels influence revenue through Google spend:",
    "",
):
    print(header_line)

for channel_name, channel_effects in mediation_effects.items():
    pretty_name = channel_name.replace('_', ' ').title()
    print(f"{pretty_name}:")
    print(f"  Direct Effect:     {channel_effects['direct_effect']:8.2f}")
    print(f"  Indirect Effect:   {channel_effects['indirect_effect']:8.2f}")
    print(f"  Total Effect:      {channel_effects['total_effect']:8.2f}")
    print(f"  Mediation Ratio:   {channel_effects['mediation_ratio']:8.2%}")
    print()

# Aggregate mediation strength: share of absolute effect magnitude that flows
# through Google spend. The 1e-8 term keeps the ratio defined (0) when every
# effect is zero.
total_mediation = sum([abs(e['indirect_effect']) for e in mediation_effects.values()])
total_direct = sum([abs(e['direct_effect']) for e in mediation_effects.values()])
overall_mediation_ratio = total_mediation / (total_mediation + total_direct + 1e-8)

print(f"Overall Mediation Strength: {overall_mediation_ratio:.2%}")
print(f"Total Indirect Effects: {total_mediation:.2f}")
print(f"Total Direct Effects: {total_direct:.2f}")


## 4. Feature Importance Analysis

Let's examine which features are most important in each stage:


In [None]:
# Feature importance analysis
# Importances come from the fitted model; the axis labels below assume they are
# absolute coefficient values — confirm against MediationMMM.
feature_importance = results['feature_importance']

# Stage 1 importance (Social → Google)
print("STAGE 1 FEATURE IMPORTANCE (Social Channels → Google Spend)")
print("=" * 60)
stage1_importance = feature_importance['stage1']
sorted_stage1 = sorted(stage1_importance.items(), key=lambda x: x[1], reverse=True)

for i, (feature, importance) in enumerate(sorted_stage1, 1):
    print(f"{i:2d}. {feature.replace('_', ' ').title():25s} {importance:8.4f}")

print("\nSTAGE 2 FEATURE IMPORTANCE (Google + Direct → Revenue)")
print("=" * 60)
stage2_importance = feature_importance['stage2']
sorted_stage2 = sorted(stage2_importance.items(), key=lambda x: x[1], reverse=True)

for i, (feature, importance) in enumerate(sorted_stage2, 1):
    print(f"{i:2d}. {feature.replace('_', ' ').title():25s} {importance:8.4f}")

# Visualize feature importance
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Stage 1
features1 = [f.replace('_', ' ').title() for f, _ in sorted_stage1]
importance1 = [imp for _, imp in sorted_stage1]
ax1.barh(features1, importance1, color='skyblue')
# barh places the first category at the bottom; invert so the most important
# feature is on top, matching the printed ranking.
ax1.invert_yaxis()
ax1.set_title('Stage 1: Social Channels → Google Spend')
ax1.set_xlabel('Feature Importance (|Coefficient|)')

# Stage 2 (top 10)
top_features2 = sorted_stage2[:10]
features2 = [f.replace('_', ' ').title() for f, _ in top_features2]
importance2 = [imp for _, imp in top_features2]
ax2.barh(features2, importance2, color='lightcoral')
ax2.invert_yaxis()
ax2.set_title('Stage 2: Google + Direct → Revenue (Top 10)')
ax2.set_xlabel('Feature Importance (|Coefficient|)')

plt.tight_layout()
plt.show()


## 5. Time Series Cross-Validation

Let's perform proper time series cross-validation to assess model stability:


In [None]:
# Perform time series cross-validation
print("Performing time series cross-validation...")

# Single source of truth for the fold count; the plots below derive their x-axis
# from the returned per-fold values rather than a hard-coded range.
n_splits = 5
cv_results = model.cross_validate(
    data=prepared_data,
    n_splits=n_splits,
    target_col='revenue'
)

print("\nCROSS-VALIDATION RESULTS")
print("=" * 40)

# Stage 1 results
stage1_r2_mean = cv_results['stage1_r2']['mean']
stage1_r2_std = cv_results['stage1_r2']['std']
stage1_rmse_mean = cv_results['stage1_rmse']['mean']
stage1_rmse_std = cv_results['stage1_rmse']['std']

print("Stage 1 (Social → Google):")
print(f"  R²: {stage1_r2_mean:.4f} ± {stage1_r2_std:.4f}")
print(f"  RMSE: {stage1_rmse_mean:.2f} ± {stage1_rmse_std:.2f}")

# Stage 2 results
stage2_r2_mean = cv_results['stage2_r2']['mean']
stage2_r2_std = cv_results['stage2_r2']['std']
stage2_rmse_mean = cv_results['stage2_rmse']['mean']
stage2_rmse_std = cv_results['stage2_rmse']['std']

print("\nStage 2 (Google + Direct → Revenue):")
print(f"  R²: {stage2_r2_mean:.4f} ± {stage2_r2_std:.4f}")
print(f"  RMSE: {stage2_rmse_mean:.2f} ± {stage2_rmse_std:.2f}")

# Visualize CV results. R² is unitless while RMSE is in the target's units, so
# each metric gets its own y-axis; on a shared axis the R² line would be
# flattened against zero whenever RMSE is large.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

cv_panels = [
    (ax1, 'stage1', 'Stage 1 CV Results (Social → Google)'),
    (ax2, 'stage2', 'Stage 2 CV Results (Google + Direct → Revenue)'),
]
for ax, stage, title in cv_panels:
    r2_values = cv_results[f'{stage}_r2']['values']
    rmse_values = cv_results[f'{stage}_rmse']['values']
    folds = list(range(1, len(r2_values) + 1))

    r2_line, = ax.plot(folds, r2_values, 'o-', label='R²', color='blue')
    ax.set_ylabel('R²', color='blue')

    ax_rmse = ax.twinx()
    rmse_line, = ax_rmse.plot(folds, rmse_values, 's-', label='RMSE', color='red')
    ax_rmse.set_ylabel('RMSE', color='red')

    ax.set_title(title)
    ax.set_xlabel('Fold')
    ax.set_xticks(folds)
    # Combine handles from both y-axes into one legend
    ax.legend(handles=[r2_line, rmse_line])
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## 6. Model Predictions and Visualization

Let's visualize the model predictions and compare them with actual values:


In [None]:
# Get predictions.
# Pass the same stage-2 feature list the model was fit with (direct_cols=direct_features
# in the fit cell). stage2_features additionally contains interaction/lag/moving-average
# columns that were never given to fit, so predicting with it would not match the
# trained stage-2 model.
# NOTE(review): assumes predict's third positional argument is the stage-2 feature
# list. If stage 2 is meant to use stage2_features, change the fit call instead and
# keep both in sync.
predictions = model.predict(prepared_data, social_features, direct_features)

revenue_actual = prepared_data['revenue']
google_actual = prepared_data['google_spend']
revenue_lo, revenue_hi = revenue_actual.min(), revenue_actual.max()
google_lo, google_hi = google_actual.min(), google_actual.max()

# Create prediction comparison plots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Actual vs Predicted Revenue (dashed line = perfect prediction)
axes[0, 0].scatter(revenue_actual, predictions['revenue_predictions'], alpha=0.6, color='blue')
axes[0, 0].plot([revenue_lo, revenue_hi], [revenue_lo, revenue_hi], 'r--', alpha=0.8)
axes[0, 0].set_xlabel('Actual Revenue')
axes[0, 0].set_ylabel('Predicted Revenue')
axes[0, 0].set_title('Actual vs Predicted Revenue')
axes[0, 0].grid(True, alpha=0.3)

# Revenue over time
axes[0, 1].plot(prepared_data['date'], revenue_actual, label='Actual', color='blue', linewidth=2)
axes[0, 1].plot(prepared_data['date'], predictions['revenue_predictions'], label='Predicted', color='red', linewidth=2, alpha=0.8)
axes[0, 1].set_xlabel('Date')
axes[0, 1].set_ylabel('Revenue')
axes[0, 1].set_title('Revenue Over Time')
axes[0, 1].legend()
axes[0, 1].tick_params(axis='x', rotation=45)
axes[0, 1].grid(True, alpha=0.3)

# Google spend predictions (stage 1)
axes[1, 0].scatter(google_actual, predictions['google_predictions'], alpha=0.6, color='green')
axes[1, 0].plot([google_lo, google_hi], [google_lo, google_hi], 'r--', alpha=0.8)
axes[1, 0].set_xlabel('Actual Google Spend')
axes[1, 0].set_ylabel('Predicted Google Spend')
axes[1, 0].set_title('Actual vs Predicted Google Spend')
axes[1, 0].grid(True, alpha=0.3)

# Residuals over time
residuals = revenue_actual - predictions['revenue_predictions']
axes[1, 1].plot(prepared_data['date'], residuals, color='purple', linewidth=1)
axes[1, 1].axhline(y=0, color='red', linestyle='--', alpha=0.8)
axes[1, 1].set_xlabel('Date')
axes[1, 1].set_ylabel('Residuals')
axes[1, 1].set_title('Revenue Residuals Over Time')
axes[1, 1].tick_params(axis='x', rotation=45)
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print prediction accuracy metrics (as reported by model.fit)
print("PREDICTION ACCURACY METRICS")
print("=" * 30)
print(f"Revenue R²: {results['stage2_metrics']['r2']:.4f}")
print(f"Revenue RMSE: ${results['stage2_metrics']['rmse']:,.0f}")
print(f"Revenue MAPE: {results['stage2_metrics']['mape']:.2f}%")
print(f"Google Spend R²: {results['stage1_metrics']['r2']:.4f}")
print(f"Google Spend RMSE: ${results['stage1_metrics']['rmse']:,.0f}")
print(f"Google Spend MAPE: {results['stage1_metrics']['mape']:.2f}%")


## 7. Save Model and Results

Let's save our trained model and results for further analysis:


In [None]:
# Save the trained model.
# Create output directories up front: open() and DataFrame.to_csv do not create
# missing parent directories (save_model may not either).
os.makedirs('../results', exist_ok=True)
os.makedirs('../data/processed', exist_ok=True)

model.save_model('../results/mediation_mmm_model.pkl')
print("✅ Model saved to ../results/mediation_mmm_model.pkl")

# Save results summary
import json


def _to_json_serializable(obj):
    """Convert values the json module cannot encode natively.

    json calls this only for objects it cannot serialize itself. Numpy scalars
    and arrays, and pandas containers, become native Python values/lists instead
    of str() reprs (which truncate large arrays). Anything else falls back to str().
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, pd.Series, pd.Index)):
        return obj.tolist()
    if isinstance(obj, pd.DataFrame):
        return obj.to_dict(orient='list')
    return str(obj)


results_summary = {
    'model_performance': {
        'stage1_r2': results['stage1_metrics']['r2'],
        'stage1_rmse': results['stage1_metrics']['rmse'],
        'stage1_mape': results['stage1_metrics']['mape'],
        'stage2_r2': results['stage2_metrics']['r2'],
        'stage2_rmse': results['stage2_metrics']['rmse'],
        'stage2_mape': results['stage2_metrics']['mape']
    },
    'cv_performance': {
        'stage1_r2_mean': cv_results['stage1_r2']['mean'],
        'stage1_r2_std': cv_results['stage1_r2']['std'],
        'stage1_rmse_mean': cv_results['stage1_rmse']['mean'],
        'stage1_rmse_std': cv_results['stage1_rmse']['std'],
        'stage2_r2_mean': cv_results['stage2_r2']['mean'],
        'stage2_r2_std': cv_results['stage2_r2']['std'],
        'stage2_rmse_mean': cv_results['stage2_rmse']['mean'],
        'stage2_rmse_std': cv_results['stage2_rmse']['std']
    },
    'mediation_effects': mediation_effects,
    'feature_importance': feature_importance
}

with open('../results/model_results_summary.json', 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2, default=_to_json_serializable)

print("✅ Results summary saved to ../results/model_results_summary.json")

# Save prepared data
prepared_data.to_csv('../data/processed/prepared_mmm_data.csv', index=False)
print("✅ Prepared data saved to ../data/processed/prepared_mmm_data.csv")

print("\n📊 MODEL SUMMARY")
print("=" * 20)
print("✅ Two-stage mediation model successfully trained")
print(f"✅ Stage 1 (Social → Google): R² = {results['stage1_metrics']['r2']:.4f}")
print(f"✅ Stage 2 (Google + Direct → Revenue): R² = {results['stage2_metrics']['r2']:.4f}")
print(f"✅ Cross-validation completed with {len(cv_results['stage1_r2']['values'])} folds")
print(f"✅ Mediation effects quantified for {len(mediation_effects)} social channels")
print("✅ Model and results saved for further analysis")
