# Basic Media Mix Modeling

This notebook introduces fundamental MMM concepts:
- Adstock (carryover effects)
- Saturation curves
- Basic regression models
- Model interpretation

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
import scipy.optimize as opt

# Load the preprocessed data.
# If the file is missing, unreadable, or malformed, fall back to synthetic
# sample data so the rest of the notebook can still be run end to end.
try:
    df = pd.read_csv('../data/processed_mmm_data.csv')
    print("Loaded preprocessed data successfully!")
except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as load_error:
    # Only recover from file/parse problems. Anything else (KeyboardInterrupt,
    # programming errors, ...) should surface instead of being silently
    # replaced by random data.
    print("Could not load preprocessed data. Please run the data preparation notebook first.")
    print(f"Reason: {type(load_error).__name__}: {load_error}")
    print("Falling back to randomly generated sample data for demonstration.")

    # Create sample data for demonstration (seeded so every run is identical)
    np.random.seed(42)
    n_weeks = 200
    dates = pd.date_range(start='2020-01-01', periods=n_weeks, freq='W')
    df = pd.DataFrame({
        'date': dates,
        'tv_spend': np.random.gamma(2, 5000, n_weeks),
        'digital_spend': np.random.gamma(3, 3000, n_weeks),
        'radio_spend': np.random.gamma(1.5, 2000, n_weeks),
        'print_spend': np.random.gamma(1, 1000, n_weeks),
        # NOTE: conversions are drawn independently of spend, so models fitted
        # on this fallback data will not recover meaningful media effects.
        'conversions': np.random.poisson(1500, n_weeks)
    })

## 1. Adstock Transformation

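Adstock models the carryover effect of marketing activities: spend in one period keeps contributing in later periods, with its weight shrinking geometrically (`decay_rate ** lag`) up to a maximum lag.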

In [None]:
def adstock_transformation(x, decay_rate=0.5, max_lag=8):
    """
    Apply a geometric adstock (carryover) transformation to a media channel.

    Each output value is the current input plus the previous ``max_lag``
    inputs, each weighted by ``decay_rate ** lag``:

        adstocked[t] = sum(x[t - j] * decay_rate ** j for j in 0..max_lag)

    Lags that would reach before the start of the series are skipped, so the
    first ``max_lag`` outputs are built from fewer terms.

    Parameters:
    x: array-like, media spend values in time order (NumPy array, list or
       pandas Series; values are read by position, never by index label)
    decay_rate: float, rate at which effect decays (0-1)
    max_lag: int, maximum number of periods for carryover

    Returns:
    numpy.ndarray of floats with the same shape as ``x``.
    """
    # Convert up front: indexing a pandas Series with integers is label-based,
    # which breaks (or silently misreads) slices whose index does not start at 0.
    values = np.asarray(x, dtype=float)
    n_periods = len(values)
    adstocked = np.zeros_like(values)

    # One vectorised pass per lag instead of a Python loop over every period.
    # Terms are still accumulated in increasing-lag order for each element.
    for lag in range(min(max_lag, n_periods - 1) + 1):
        adstocked[lag:] += values[:n_periods - lag] * (decay_rate ** lag)

    return adstocked

# Example: Apply adstock to TV spend
tv_spend = df['tv_spend'].values
tv_adstocked = adstock_transformation(tv_spend, decay_rate=0.7)

# Visualize the effect
fig, (ax_series, ax_decay) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: raw vs. adstocked spend for the first 52 rows
ax_series.plot(tv_spend[:52], label='Original TV Spend', alpha=0.7)
ax_series.plot(tv_adstocked[:52], label='Adstocked TV Spend', alpha=0.7)
ax_series.set_title('Adstock Effect on TV Spend (First Year)')
ax_series.set_xlabel('Week')
ax_series.set_ylabel('Spend')
ax_series.legend()

# Right panel: the same 10-week window under several decay rates.
# `decay_rates` is the list that is actually plotted (it used to be an unused
# np.arange while the loop iterated over a separate hard-coded list).
decay_rates = [0.3, 0.5, 0.7, 0.9]
sample_week = tv_spend[50:60]  # 10 weeks sample
for rate in decay_rates:
    adstocked_sample = adstock_transformation(sample_week, decay_rate=rate)
    ax_decay.plot(adstocked_sample, label=f'Decay Rate: {rate}', marker='o')

ax_decay.set_title('Effect of Different Decay Rates')
ax_decay.set_xlabel('Week')
ax_decay.set_ylabel('Adstocked Spend')
ax_decay.legend()

fig.tight_layout()
plt.show()

## 2. Saturation Curves

Saturation curves model diminishing returns of marketing spend.

In [None]:
def hill_saturation(x, half_saturation=1.0, shape=1.0):
    """
    Hill saturation transform.

    Evaluates ``x**shape / (half_saturation**shape + x**shape)``, which equals
    0.5 when ``x == half_saturation``.

    Parameters:
    x: input values (e.g., adstocked spend); must support ``**`` and ``/``,
       such as a NumPy array or a scalar
    half_saturation: float, input value at which the output is 0.5
    shape: float, exponent controlling how steep the curve is
    """
    powered_input = x ** shape
    powered_half = half_saturation ** shape
    return powered_input / (powered_half + powered_input)

def diminishing_returns(x, alpha=0.5):
    """
    Power-law diminishing-returns transform: raises ``x`` to the power ``alpha``.

    Parameters:
    x: input values; must support exponentiation (NumPy array or scalar)
    alpha: float, exponent; the notebook uses values between 0 and 1
    """
    return pow(x, alpha)

# Visualize different saturation curves
x_range = np.linspace(0, 10000, 1000)

fig, (ax_half, ax_shape, ax_power) = plt.subplots(1, 3, figsize=(15, 5))

# Hill saturation with different parameters
for half_sat in (2000, 4000, 6000):
    y = hill_saturation(x_range, half_saturation=half_sat, shape=2)
    ax_half.plot(x_range, y, label=f'Half-sat: {half_sat}')
ax_half.set_title('Hill Saturation - Different Half-saturation Points')
ax_half.set_xlabel('Spend')
ax_half.set_ylabel('Saturated Effect')
ax_half.legend()

# Hill saturation with different shapes
for shape in (0.5, 1.0, 2.0, 4.0):
    y = hill_saturation(x_range, half_saturation=4000, shape=shape)
    ax_shape.plot(x_range, y, label=f'Shape: {shape}')
ax_shape.set_title('Hill Saturation - Different Shapes')
ax_shape.set_xlabel('Spend')
ax_shape.set_ylabel('Saturated Effect')
ax_shape.legend()

# Diminishing returns (input is divided by 1000 before the transform so the
# curves are comparable on one axis; the x-axis still shows unscaled spend)
for alpha in (0.3, 0.5, 0.7, 0.9):
    y = diminishing_returns(x_range / 1000, alpha=alpha)
    ax_power.plot(x_range, y, label=f'Alpha: {alpha}')
ax_power.set_title('Diminishing Returns Curves')
ax_power.set_xlabel('Spend')
ax_power.set_ylabel('Effect')
ax_power.legend()

fig.tight_layout()
plt.show()

## 3. Basic MMM Model Implementation

In [None]:
class BasicMMM:
    """
    Minimal media mix model.

    Each media channel is passed through adstock and a power-law saturation
    transform; trend and 52-period sine/cosine seasonality terms are added when
    a 'date' column is present; the result is fitted with Ridge regression.
    """

    # Per-channel transformation parameters used when nothing else is supplied.
    DEFAULT_PARAMS = {
        'tv_spend': {'adstock': 0.7, 'saturation': 0.6},
        'digital_spend': {'adstock': 0.4, 'saturation': 0.8},
        'radio_spend': {'adstock': 0.6, 'saturation': 0.7},
        'print_spend': {'adstock': 0.8, 'saturation': 0.5}
    }

    def __init__(self):
        self.model = None
        self.channels = ['tv_spend', 'digital_spend', 'radio_spend', 'print_spend']
        # Filled by fit(): channel -> adstock decay / saturation alpha in use.
        self.adstock_params = {}
        self.saturation_params = {}
        # Column order of the design matrix seen by fit().
        self.feature_names = []
        # Time reference learned by fit() so that trend/seasonality computed
        # for later data continue from the training period instead of
        # restarting at zero.
        self._time_origin = None
        self._time_step = None

    def _resolve_params(self, params=None):
        """
        Build the full channel -> {'adstock', 'saturation'} mapping.

        Precedence per value: explicit `params` entry, then the value stored
        by the last fit(), then DEFAULT_PARAMS, then 0.5 (the default of
        transform_media) for channels without a default entry.
        """
        overrides = params or {}
        resolved = {}
        for channel in self.channels:
            base = self.DEFAULT_PARAMS.get(channel, {})
            spec = overrides.get(channel) or {}
            resolved[channel] = {
                'adstock': spec.get(
                    'adstock',
                    self.adstock_params.get(channel, base.get('adstock', 0.5))
                ),
                'saturation': spec.get(
                    'saturation',
                    self.saturation_params.get(channel, base.get('saturation', 0.5))
                ),
            }
        return resolved

    def _set_time_reference(self, data):
        """
        Remember the earliest training date and the typical spacing between
        dates. Leaves the reference unset (positional trend) when the 'date'
        column is missing, unparseable, or has fewer than two distinct dates.
        """
        self._time_origin = None
        self._time_step = None
        if 'date' not in data.columns or len(data) == 0:
            return
        try:
            dates = pd.to_datetime(data['date'])
        except (ValueError, TypeError):
            return
        if dates.isna().any():
            return
        steps = dates.sort_values().diff().dropna()
        steps = steps[steps > pd.Timedelta(0)]
        if steps.empty:
            return
        self._time_origin = dates.min()
        self._time_step = steps.median()

    def _time_index(self, data):
        """
        Return one time value per row: periods elapsed since the training
        origin when a reference is available, otherwise the row position.
        """
        if self._time_origin is not None:
            try:
                dates = pd.to_datetime(data['date'])
                elapsed = (dates - self._time_origin) / self._time_step
            except (ValueError, TypeError):
                elapsed = None
            if elapsed is not None and not elapsed.isna().any():
                return elapsed.to_numpy(dtype=float)
        return np.arange(len(data))

    def transform_media(self, data, channel, adstock_decay=0.5, saturation_alpha=0.5):
        """
        Apply adstock and saturation transformations to a media channel.

        Parameters:
        data: DataFrame containing the column `channel`
        channel: str, name of the spend column to transform
        adstock_decay: float, decay rate passed to adstock_transformation
        saturation_alpha: float, exponent passed to diminishing_returns
        """
        # Step 1: Adstock transformation
        adstocked = adstock_transformation(data[channel].values, decay_rate=adstock_decay)

        # Step 2: Saturation transformation
        saturated = diminishing_returns(adstocked, alpha=saturation_alpha)

        return saturated

    def prepare_features(self, data, params=None):
        """
        Prepare features with media transformations.

        Parameters:
        data: DataFrame with any of the spend columns in `self.channels` and,
              optionally, a 'date' column
        params: optional dict channel -> {'adstock': ..., 'saturation': ...};
                missing channels or keys fall back to the values stored by
                fit(), then to DEFAULT_PARAMS

        Returns:
        DataFrame indexed like `data` with one '<channel>_transformed' column
        per available channel, plus 'trend', 'sin_seasonality' and
        'cos_seasonality' when `data` has a 'date' column.
        """
        resolved = self._resolve_params(params)

        features = pd.DataFrame(index=data.index)

        # Transform each media channel
        for channel in self.channels:
            if channel in data.columns:
                transformed = self.transform_media(
                    data, channel,
                    adstock_decay=resolved[channel]['adstock'],
                    saturation_alpha=resolved[channel]['saturation']
                )
                features[f'{channel}_transformed'] = transformed

        # Add trend and seasonality if available
        if 'date' in data.columns:
            # After fit() this continues from the training origin, so held-out
            # rows get the trend/seasonal phase that matches their dates.
            time_index = self._time_index(data)
            features['trend'] = time_index
            # Simple seasonality (can be improved); period of 52 rows/steps
            features['sin_seasonality'] = np.sin(2 * np.pi * time_index / 52)
            features['cos_seasonality'] = np.cos(2 * np.pi * time_index / 52)

        return features

    def fit(self, data, target_col='conversions', alpha=1.0, params=None):
        """
        Fit the MMM model.

        Parameters:
        data: training DataFrame (spend columns, target and optionally 'date')
        target_col: str, name of the target column
        alpha: float, Ridge regularisation strength
        params: optional per-channel transformation parameters (same format as
                in prepare_features); they are stored and reused by predict()

        Returns:
        self
        """
        # Start from a clean slate so a plain fit(data) always uses the defaults.
        self.adstock_params = {}
        self.saturation_params = {}
        resolved = self._resolve_params(params)
        self.adstock_params = {ch: spec['adstock'] for ch, spec in resolved.items()}
        self.saturation_params = {ch: spec['saturation'] for ch, spec in resolved.items()}
        self._set_time_reference(data)

        X = self.prepare_features(data)
        y = data[target_col]

        # Use Ridge regression for regularization
        self.model = Ridge(alpha=alpha)
        self.model.fit(X, y)

        # Store feature names
        self.feature_names = X.columns.tolist()

        return self

    def predict(self, data):
        """
        Make predictions using the parameters and time reference from fit().

        Adstock is computed from `data` alone, so the first rows of a held-out
        slice carry no spend history from earlier periods.
        """
        X = self.prepare_features(data)
        return self.model.predict(X)

    def get_coefficients(self):
        """
        Get model coefficients, sorted by absolute value (largest first).

        Returns None when the model has not been fitted.
        """
        if self.model is None:
            return None

        return pd.DataFrame({
            'feature': self.feature_names,
            'coefficient': self.model.coef_
        }).sort_values('coefficient', key=abs, ascending=False)

# Example usage
mmm = BasicMMM()

# Split data by row order: the first 80% of rows train the model, the rest test it
split_at = int(0.8 * len(df))
train_data, test_data = df.iloc[:split_at], df.iloc[split_at:]

# Fit model
mmm.fit(train_data)

# Make predictions
train_pred = mmm.predict(train_data)
test_pred = mmm.predict(test_data)

# Report fit quality on both splits
train_r2 = r2_score(train_data['conversions'], train_pred)
test_r2 = r2_score(test_data['conversions'], test_pred)
test_mae = mean_absolute_error(test_data['conversions'], test_pred)

print(f"Train R²: {train_r2:.3f}")
print(f"Test R²: {test_r2:.3f}")
print(f"Test MAE: {test_mae:.0f}")

## 4. Model Interpretation

In [None]:
# Get and visualize coefficients
coefficients = mmm.get_coefficients()
print("Model Coefficients:")
print(coefficients)

# Plot coefficients
fig, (ax_coef, ax_fit) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: coefficients of the transformed media features only
is_media_feature = coefficients['feature'].str.contains('_transformed')
media_coefs = coefficients[is_media_feature]
ax_coef.barh(media_coefs['feature'], media_coefs['coefficient'])
ax_coef.set_title('Media Channel Coefficients')
ax_coef.set_xlabel('Coefficient Value')

# Right panel: actual vs predicted, with the y = x reference line in red
actual = test_data['conversions']
diagonal = [actual.min(), actual.max()]
ax_fit.scatter(actual, test_pred, alpha=0.6)
ax_fit.plot(diagonal, diagonal, 'r--')
ax_fit.set_xlabel('Actual Conversions')
ax_fit.set_ylabel('Predicted Conversions')
ax_fit.set_title('Actual vs Predicted (Test Set)')

fig.tight_layout()
plt.show()

## 5. Media Contribution Analysis

In [None]:
def calculate_media_contributions(mmm_model, data):
    """
    Calculate the contribution of each media channel to total conversions.

    A channel's contribution is its fitted coefficient multiplied by its
    transformed feature, summed over all rows of `data`. Non-media features
    (anything whose name lacks '_transformed') are excluded, as is the
    intercept. Values are negative when the coefficient is negative.

    Parameters:
    mmm_model: fitted model exposing `prepare_features(data)`, `feature_names`
               and `model.coef_` (e.g. a fitted BasicMMM)
    data: DataFrame to compute contributions for

    Returns:
    dict mapping channel name (feature name without '_transformed') to its
    summed contribution.
    """
    features = mmm_model.prepare_features(data)

    contributions = {}
    # feature_names and coef_ share the column order of the fitted design
    # matrix, so they can be walked in lockstep. No prediction is needed here.
    for feature, coefficient in zip(mmm_model.feature_names, mmm_model.model.coef_):
        if '_transformed' in feature:
            channel_name = feature.replace('_transformed', '')
            contributions[channel_name] = (features[feature] * coefficient).sum()

    return contributions

# Calculate contributions
contributions = calculate_media_contributions(mmm, test_data)

# Visualize contributions
plt.figure(figsize=(10, 6))
channels = list(contributions.keys())
values = list(contributions.values())

# A pie chart needs non-negative wedges with a positive total (plt.pie raises
# ValueError on negative values), but regression coefficients, and therefore
# contributions, can be negative. Fall back to a signed bar chart in that case.
if values and min(values) >= 0 and sum(values) > 0:
    plt.pie(values, labels=channels, autopct='%1.1f%%', startangle=90)
    plt.axis('equal')
else:
    plt.barh(channels, values)
    plt.axvline(0, color='black', linewidth=0.8)
    plt.xlabel('Contribution (conversions)')
plt.title('Media Channel Contribution to Conversions')
plt.show()

print("Channel Contributions:")
for channel, contrib in contributions.items():
    print(f"{channel}: {contrib:,.0f} conversions")

## Exercise: Parameter Optimization

Try to optimize the adstock and saturation parameters:

In [None]:
# TODO: Implement parameter optimization
# Hint: Use scipy.optimize to find the best adstock and saturation parameters
# that minimize prediction error

def objective_function(params, data, target):
    """
    Objective function for optimization (exercise stub).

    Parameters:
    params: array of [tv_adstock, tv_sat, digital_adstock, digital_sat, ...]
    data: DataFrame to evaluate the model on
    target: str, name of the target column in `data`

    Expected return value once implemented: a single number measuring the
    prediction error for these parameters (lower is better), since
    scipy.optimize.minimize requires a scalar objective.

    One possible approach: turn `params` into the per-channel dict accepted
    by BasicMMM.prepare_features(data, params), fit a regression on those
    features and return an error metric such as mean_absolute_error.

    Raises:
    NotImplementedError: until the exercise is completed, so an accidental
    call fails with a clear message instead of silently returning None.
    """
    # Your implementation here
    raise NotImplementedError(
        "objective_function is an exercise stub: implement it to return the "
        "prediction error for the given parameters."
    )

# Example optimization call (two parameters per channel, each kept in (0, 1)):
# initial_params = [0.5] * 8
# bounds = [(0.01, 0.99)] * 8
# result = opt.minimize(objective_function, initial_params, args=(train_data, 'conversions'), bounds=bounds)

print("Implement parameter optimization above!")

## Next Steps

In the next module (`04-optimization/`), we'll cover:
- Budget allocation optimization
- Scenario planning
- ROI and ROAS calculations
- Advanced optimization techniques

You now have a basic understanding of MMM modeling!