# 🌊 Tributary: Hierarchical MMM for Music Marketing

**VOLTA Music Group Case Study**

*Can structure beat data size?*

---

This notebook walks through building and comparing Marketing Mix Models
for a music distribution company operating across 8 European markets.

**The Problem**: Poland and Sweden launched 6 months ago, while Germany has 2 years
of data. How do we get reliable ROAS (return on ad spend) estimates for ALL markets?

**The Solution**: Hierarchical models with partial pooling.

---

## Table of Contents

1. [Setup & Data Generation](#1-setup)
2. [Exploring the Data](#2-data-exploration)
3. [Understanding Transforms](#3-transforms)
4. [Model Architectures](#4-models)
5. [Diagnostics](#5-diagnostics)
6. [ROAS Analysis](#6-roas)
7. [The Magic: Shrinkage](#7-shrinkage)
8. [Model Comparison](#8-comparison)
9. [Ground Truth Validation](#9-validation)
10. [When Hierarchical Fails](#10-failure-modes)
11. [Budget Allocation](#11-budget)
12. [Key Takeaways](#12-takeaways)


<a id="1-setup"></a>
## 1. Setup & Data Generation


In [None]:
# Core imports
import warnings

warnings.filterwarnings("ignore")

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

# Set plotting style
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = (12, 6)
plt.rcParams["font.size"] = 11

# Random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

print("✅ Core imports loaded")

In [None]:
# Tributary imports
from tributary.data.synthetic import (
    generate_synthetic_mmm_data,
    SyntheticDataConfig,
    summarize_dataset,
)
from tributary.data.schemas import MarketingDataFrame
from tributary.transforms import (
    geometric_adstock,
    delayed_adstock,
    MUSIC_CHANNEL_ADSTOCK_DEFAULTS,
    MUSIC_CHANNEL_SATURATION_DEFAULTS,
    plot_adstock_decay,
    plot_saturation_curve,
)
from tributary.models import (
    build_pooled_mmm,
    build_unpooled_mmm,
    build_hierarchical_mmm,
)
from tributary.evaluation import (
    run_mcmc_diagnostics,
    compute_roas_from_trace,
    format_roas_report,
    compute_shrinkage,
    roas_stability_comparison,
    compare_to_ground_truth,
    compute_optimal_allocation,
)

print("✅ Tributary imports loaded")

### Generate Synthetic Data

We create the VOLTA scenario: 8 European markets, 6 music marketing channels,
with varying data lengths (Germany: 2 years, Poland: 6 months).


In [None]:
# Configure the VOLTA scenario
config = SyntheticDataConfig(
    countries=["DE", "FR", "UK", "NL", "ES", "IT", "PL", "SE"],
    channels=[
        "spotify_ads_spend",
        "meta_spend",
        "tiktok_spend",
        "youtube_spend",
        "radio_spend",
        "playlist_spend",
    ],
    weeks_per_country={
        "DE": 104,  # 2 years - mature market
        "UK": 104,  # 2 years
        "FR": 78,  # 1.5 years
        "NL": 78,  # 1.5 years
        "ES": 52,  # 1 year
        "IT": 52,  # 1 year
        "PL": 26,  # 6 months - SPARSE!
        "SE": 26,  # 6 months - SPARSE!
    },
    random_seed=RANDOM_SEED,
)

# Generate data with known ground truth
df, true_params = generate_synthetic_mmm_data(config, random_seed=RANDOM_SEED)

# Validate
MarketingDataFrame.validate(df)

print(f"✅ Generated {len(df)} observations")
print(f"📅 Date range: {df['date'].min().date()} to {df['date'].max().date()}")

In [None]:
# Dataset summary
summarize_dataset(df)

<a id="2-data-exploration"></a>
## 2. Exploring the Data

The key challenge: **unequal data availability** across markets.


In [None]:
# Data availability visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart of observations per country
data_counts = df.groupby("country").size().sort_values()
colors = [
    "#e74c3c" if x < 40 else "#f39c12" if x < 80 else "#27ae60"
    for x in data_counts.values
]

ax1 = axes[0]
data_counts.plot(kind="barh", ax=ax1, color=colors)
ax1.axvline(52, color="gray", linestyle="--", alpha=0.7, label="1 year")
ax1.axvline(104, color="gray", linestyle=":", alpha=0.7, label="2 years")
ax1.set_xlabel("Weeks of Data")
ax1.set_title("Data Availability by Market\n(Red = Sparse, Green = Rich)")
ax1.legend()

# Revenue distribution by country
ax2 = axes[1]
df.boxplot(column="streaming_revenue", by="country", ax=ax2)
ax2.set_title("Streaming Revenue Distribution by Market")
ax2.set_xlabel("Country")
ax2.set_ylabel("Weekly Streaming Revenue (€)")
plt.suptitle("")  # Remove automatic title

plt.tight_layout()
plt.show()

print("\n⚠️  Poland and Sweden have only 26 weeks of data!")
print("   Traditional approach: 'We need more data'")
print("   Our approach: 'We need better STRUCTURE'")

In [None]:
# Channel spend patterns over time
channel_cols = [c for c in df.columns if c.endswith("_spend")]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for idx, channel in enumerate(channel_cols):
    ax = axes[idx]

    for country, style in [("DE", "-"), ("PL", "--")]:
        country_data = df[df["country"] == country].sort_values("date")
        ax.plot(
            country_data["date"],
            country_data[channel] / 1000,
            style,
            label=country,
            linewidth=1.5,
            alpha=0.8,
        )

    ax.set_title(channel.replace("_spend", "").replace("_", " ").title())
    ax.set_ylabel("Spend (€K)")
    ax.legend()
    ax.tick_params(axis="x", rotation=45)

plt.suptitle("Spend Patterns: Germany (solid) vs Poland (dashed)", y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

<a id="3-transforms"></a>
## 3. Understanding Transforms

Marketing spend doesn't translate directly to streaming revenue. We need to model:

1. **Adstock (Carryover)**: Effects persist and decay over time
2. **Saturation (Diminishing Returns)**: More spend ≠ proportionally more effect
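To make carryover concrete, here is a minimal NumPy sketch of truncated geometric adstock, assuming the common form y[t] = sum over l of w[l] * x[t - l] with w[l] = alpha**l (optionally normalized so the weights sum to 1). `geometric_adstock_sketch` and `half_life` are hypothetical helpers; Tributary's own `geometric_adstock` may differ in details such as normalization or boundary handling.

```python
import numpy as np


def geometric_adstock_sketch(x, alpha, l_max=12, normalize=True):
    """Hypothetical re-implementation of truncated geometric adstock.

    y[t] = sum_{l=0}^{l_max-1} w[l] * x[t - l], with w[l] = alpha**l.
    """
    x = np.asarray(x, dtype=float)
    weights = alpha ** np.arange(l_max)
    if normalize:
        # Weights sum to 1, so total spend is preserved, just spread over time
        weights = weights / weights.sum()
    # Keeping the first len(x) entries of the full convolution means y[t]
    # only uses x[t], x[t-1], ..., x[t-l_max+1] (no look-ahead)
    return np.convolve(x, weights)[: len(x)]


def half_life(alpha):
    """Weeks h until the geometric weight alpha**h falls to 0.5."""
    return np.log(0.5) / np.log(alpha)


pulse = np.zeros(20)
pulse[2] = 100_000  # €100K in week 2 (0-indexed)
carryover = geometric_adstock_sketch(pulse, alpha=0.5, l_max=4, normalize=False)
print(carryover[:8])
print(f"Half-life at alpha=0.5: {half_life(0.5):.1f} weeks")
```

With `normalize=True` the pulse's €100K is spread over the window instead of being amplified, which is why normalized effective-spend curves start below the height of the original pulse.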


In [None]:
# Adstock visualization
print("📊 ADSTOCK: How effects decay over time\n")

# Show default parameters
adstock_table = pd.DataFrame(
    [
        {
            "Channel": ch.replace("_spend", ""),
            "α (decay)": params["alpha"],
            "θ (delay)": params.get("theta", 0),
            # Geometric half-life: solve α**h = 0.5 (ignores any θ delay)
            "Half-life": f"{np.log(0.5) / np.log(params['alpha']):.1f} weeks",
        }
        for ch, params in MUSIC_CHANNEL_ADSTOCK_DEFAULTS.items()
    ]
)
print(adstock_table.to_string(index=False))

# Plot decay curves
alphas = {
    "TikTok (α=0.35, fast)": 0.35,
    "Meta (α=0.55, moderate)": 0.55,
    "Radio (α=0.70, slow)": 0.70,
}
fig = plot_adstock_decay(
    alphas, l_max=12, title="Adstock Decay: How Long Does the Effect Last?"
)
plt.show()

In [None]:
# Interactive adstock demo
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original spend (pulse)
spend_pulse = np.zeros(20)
spend_pulse[2] = 100000  # €100K in week 2 (0-indexed)

axes[0].bar(range(20), spend_pulse / 1000, color="steelblue", alpha=0.7)
axes[0].set_title("1. Original Spend (single pulse)")
axes[0].set_xlabel("Week")
axes[0].set_ylabel("Spend (€K)")

# Adstock with different decay rates
for alpha, color, label in [
    (0.35, "green", "TikTok α=0.35"),
    (0.55, "orange", "Meta α=0.55"),
    (0.70, "red", "Radio α=0.70"),
]:
    adstocked = geometric_adstock(spend_pulse, alpha=alpha, l_max=12, normalize=True)
    axes[1].plot(adstocked / 1000, color=color, label=label, linewidth=2)

axes[1].set_title("2. After Adstock (carryover)")
axes[1].set_xlabel("Week")
axes[1].set_ylabel("Effective Spend (€K)")
axes[1].legend()

# Delayed adstock (radio/playlist)
for theta, color, label in [
    (0, "blue", "θ=0 (immediate)"),
    (2, "orange", "θ=2 (slight delay)"),
    (4, "red", "θ=4 (delayed peak)"),
]:
    delayed = delayed_adstock(
        spend_pulse, alpha=0.6, theta=theta, l_max=12, normalize=True
    )
    axes[2].plot(delayed / 1000, color=color, label=label, linewidth=2)

axes[2].set_title("3. Delayed Adstock (peak shift)")
axes[2].set_xlabel("Week")
axes[2].set_ylabel("Effective Spend (€K)")
axes[2].legend()

plt.tight_layout()
plt.show()

print("\n💡 Key Insight: TikTok effects fade fast. Radio effects persist for weeks.")
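The delayed curves can be approximated, assuming Tributary uses the delayed-adstock form common in the Bayesian MMM literature: the weight at lag l is alpha**((l - theta)**2), so the peak lands at lag theta instead of lag 0. `delayed_adstock_weights` and `apply_adstock` are hypothetical helpers, not Tributary's API.

```python
import numpy as np


def delayed_adstock_weights(alpha, theta, l_max=12, normalize=True):
    """Hypothetical helper: w[l] = alpha**((l - theta)**2), peaking at lag theta."""
    lags = np.arange(l_max)
    weights = alpha ** ((lags - theta) ** 2)
    return weights / weights.sum() if normalize else weights


def apply_adstock(x, weights):
    """Causal convolution: y[t] = sum_l weights[l] * x[t - l]."""
    x = np.asarray(x, dtype=float)
    return np.convolve(x, weights)[: len(x)]


pulse = np.zeros(20)
pulse[2] = 100_000  # €100K pulse in week 2 (0-indexed)
for theta in (0, 2, 4):
    w = delayed_adstock_weights(alpha=0.6, theta=theta)
    y = apply_adstock(pulse, w)
    print(f"theta={theta}: peak weight at lag {w.argmax()}, effective-spend peak in week {y.argmax()}")
```

One consequence of this form: with theta=0 the weights are alpha**(l**2), which decay faster than the geometric alpha**l, so a θ=0 delayed curve is not identical to plain geometric adstock with the same alpha.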

In [None]:
# Saturation visualization
print("\n📊 SATURATION: Diminishing returns\n")

# Show default parameters
sat_table = pd.DataFrame(
    [
        {
            "Channel": ch.replace("_spend", ""),
            "K (half-sat)": params["K"],
            "S (slope)": params["S"],
            "90% sat at": f"{params['K'] * (0.9 / 0.1) ** (1 / params['S']):.2f}",
        }
        for ch, params in MUSIC_CHANNEL_SATURATION_DEFAULTS.items()
    ]
)
print(sat_table.to_string(index=False))

# Plot saturation curves
sat_params = {
    "TikTok (K=0.30, steep)": {"K": 0.30, "S": 2.8},
    "Meta (K=0.50, moderate)": {"K": 0.50, "S": 2.2},
    "Radio (K=0.60, gradual)": {"K": 0.60, "S": 1.5},
}
fig = plot_saturation_curve(
    sat_params, title="Saturation: When Does More Spend Stop Helping?"
)
plt.show()

print(
    "\n💡 Key Insight: TikTok saturates quickly (viral or nothing). Radio has room to grow."
)
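The table's "90% sat at" column follows from the Hill function. This assumes Tributary uses the standard Hill form response(x) = x**S / (K**S + x**S), with x in the same (normalized) units as K. Under that form, K is the spend at 50% saturation, and the spend at any fraction p solves to K * (p / (1 - p)) ** (1 / S). Both helper names are hypothetical.

```python
import numpy as np


def hill_saturation(x, K, S):
    """Hill curve: 0 at x=0, 0.5 at x=K, approaching 1 as x grows."""
    x = np.asarray(x, dtype=float)
    return x**S / (K**S + x**S)


def spend_at_saturation(p, K, S):
    """Invert the Hill curve: the x where hill_saturation(x, K, S) == p."""
    return K * (p / (1 - p)) ** (1 / S)


for name, K, S in [("TikTok", 0.30, 2.8), ("Meta", 0.50, 2.2), ("Radio", 0.60, 1.5)]:
    x90 = spend_at_saturation(0.9, K, S)
    print(f"{name:7s} 50% at x={K:.2f}, 90% at x={x90:.2f}")
```

A larger S makes the curve steeper around K, so the distance between 50% and 90% saturation shrinks (the ratio x90 / K is 9 ** (1 / S)). That is the sense in which TikTok (S=2.8) saturates quickly compared with Radio (S=1.5).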

<a id="4-models"></a>
## 4. Model Architectures

We'll fit three models to compare:

| Model | Philosophy | Risk |
|-------|-----------|------|
| **Pooled** | All markets identical | Too rigid |
| **Unpooled** | All markets independent | Too noisy for sparse markets |
| **Hierarchical** | Similar but not identical | The sweet spot |
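One way to see how the three models relate: in a hierarchical model each country's channel effect is drawn around a global mean, beta[c] = beta_mu + beta_sigma * z[c] with z[c] ~ Normal(0, 1). This is the "non-centered" parameterization, which usually samples better with NUTS. Setting beta_sigma to 0 recovers the pooled model. Letting it grow very large leaves countries nearly unconstrained, which approaches the unpooled model. The NumPy sketch below uses illustrative numbers. It assumes, without confirmation, that `build_hierarchical_mmm` follows this structure; the names match the `beta_mu` / `beta_sigma` variables plotted in the Diagnostics section.

```python
import numpy as np


def draw_country_betas(beta_mu, beta_sigma, n_countries, rng):
    """Non-centered draw: beta[c, k] = beta_mu[k] + beta_sigma[k] * z[c, k]."""
    beta_mu = np.asarray(beta_mu, dtype=float)
    beta_sigma = np.asarray(beta_sigma, dtype=float)
    z = rng.standard_normal((n_countries, beta_mu.size))  # standardized country offsets
    return beta_mu + beta_sigma * z  # broadcasts the global values across countries


rng = np.random.default_rng(42)
beta_mu = np.array([0.8, 0.6, 0.5, 0.4, 0.3, 0.2])  # illustrative, not VOLTA's truth

pooled_like = draw_country_betas(beta_mu, np.zeros(6), 8, rng)        # sigma = 0
hierarchical = draw_country_betas(beta_mu, np.full(6, 0.1), 8, rng)   # modest spread
unpooled_like = draw_country_betas(beta_mu, np.full(6, 5.0), 8, rng)  # very wide spread

for label, betas in [
    ("pooled", pooled_like),
    ("hierarchical", hierarchical),
    ("unpooled-like", unpooled_like),
]:
    print(f"{label:14s} cross-country SD per channel: {betas.std(axis=0).round(2)}")
```

In the real model beta_sigma is itself estimated, so the data decide where between these extremes each channel sits.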


In [None]:
# Store traces for later comparison
traces = {}

### 4.1 Pooled Model

One set of parameters for all markets. Simple but assumes Germany and Poland
respond identically to Spotify ads.


In [None]:
print("🏗️  Building POOLED model...")
print("   Philosophy: 'All markets are the same'\n")

pooled_model = build_pooled_mmm(df, channel_cols)

# Quick model summary
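# Count scalar free parameters via the model's initial point (one array per free variable)
n_params = sum(np.size(v) for v in pooled_model.initial_point().values())
print(f"   Parameters: {n_params} free scalars across {len(pooled_model.free_RVs)} variables")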

In [None]:
%%time
print("🎲 Sampling pooled model...")

with pooled_model:
    traces["pooled"] = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=RANDOM_SEED,
        return_inferencedata=True,
        progressbar=True,
    )

print("\n✅ Pooled model complete!")

### 4.2 Unpooled Model

Separate parameters for each market. Maximum flexibility, but Poland's estimates
will be VERY noisy with only 26 data points.


In [None]:
print("🏗️  Building UNPOOLED model...")
print("   Philosophy: 'Every market is unique'\n")

unpooled_model = build_unpooled_mmm(df, channel_cols)

In [None]:
%%time
print("🎲 Sampling unpooled model...")

with unpooled_model:
    traces["unpooled"] = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        target_accept=0.9,
        random_seed=RANDOM_SEED,
        return_inferencedata=True,
        progressbar=True,
    )

print("\n✅ Unpooled model complete!")

### 4.3 Hierarchical Model ⭐

The star of the show. Markets are "exchangeable": similar enough to share
information, but allowed to differ where the data supports it.


In [None]:
print("🏗️  Building HIERARCHICAL model...")
print("   Philosophy: 'Markets are similar but not identical'\n")

hierarchical_model = build_hierarchical_mmm(df, channel_cols)

In [None]:
%%time
print("🎲 Sampling hierarchical model...")

with hierarchical_model:
    traces["hierarchical"] = pm.sample(
        draws=1000,
        tune=1000,
        chains=4,
        target_accept=0.95,  # Higher than above: hierarchical posteriors are prone to divergences
        random_seed=RANDOM_SEED,
        return_inferencedata=True,
        progressbar=True,
    )

print("\n✅ Hierarchical model complete!")

<a id="5-diagnostics"></a>
## 5. Diagnostics

Before trusting any ROAS estimates, we verify MCMC convergence.
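For intuition about what the R-hat numbers report, here is the classic split-R-hat (Gelman-Rubin, with each chain split in half) in NumPy. It compares between-chain and within-chain variance, and values near 1.00 suggest the chains agree. ArviZ's `az.rhat` defaults to a rank-normalized refinement of this statistic, so its values will differ slightly. We assume, but have not confirmed, that `run_mcmc_diagnostics` relies on ArviZ internally.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one scalar parameter; draws has shape (chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in half so within-chain drift also shows up as disagreement
    halves = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    n = halves.shape[1]
    within = halves.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)    # B: variance of chain means, scaled by n
    var_plus = (n - 1) / n * within + between / n    # pooled posterior variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(0)
mixed = rng.standard_normal((4, 1000))                  # four chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain stuck in another mode
print(f"mixed chains:    R-hat = {split_rhat(mixed):.3f}")
print(f"one stuck chain: R-hat = {split_rhat(stuck):.3f}")
```

A common rule of thumb is to require R-hat below 1.01 for every parameter, together with a healthy effective sample size.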


In [None]:
# Run diagnostics for all models
print("🔍 MCMC DIAGNOSTICS SUMMARY")
print("=" * 60)

for name, trace in traces.items():
    report = run_mcmc_diagnostics(trace)
    status_emoji = {"good": "✅", "warning": "⚠️", "bad": "❌"}[report.overall_status]

    print(f"\n{status_emoji} {name.upper()}")
    print(f"   Status: {report.overall_status}")
    print(f"   Divergences: {report.divergences}")
    print(f"   Max R-hat: {report.rhat_summary['rhat'].max():.4f}")
    print(f"   Min ESS (bulk): {report.ess_summary['ess_bulk'].min():.0f}")

    if report.problematic_params:
        print(f"   ⚠️  Issues: {report.problematic_params[:2]}")

print("\n" + "=" * 60)

In [None]:
# Visual diagnostics for hierarchical model
print("\n📊 Hierarchical Model: Trace Plots")

az.plot_trace(
    traces["hierarchical"],
    var_names=["beta_mu", "beta_sigma"],
    figsize=(12, 8),
)
plt.suptitle("Hierarchical Model: Global Parameters", y=1.02)
plt.tight_layout()
plt.show()

<a id="6-roas"></a>
## 6. ROAS Analysis

The key business question: "How much streaming revenue per euro spent?"
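A minimal sketch of computing ROAS per posterior draw. It assumes ROAS = (revenue attributed to a channel) / (spend on that channel), both summed over the observed weeks. This is an assumption about what `compute_roas_from_trace` does; it may instead report marginal ROAS, or use HDIs rather than the equal-tailed interval used here. Taking the ratio per draw and then summarizing keeps the full posterior uncertainty, instead of dividing two point estimates.

```python
import numpy as np


def roas_per_draw(contribution_draws, spend):
    """contribution_draws: (n_draws, n_weeks) attributed revenue; spend: (n_weeks,)."""
    contribution_draws = np.asarray(contribution_draws, dtype=float)
    spend = np.asarray(spend, dtype=float)
    return contribution_draws.sum(axis=1) / spend.sum()  # one ROAS value per draw


def summarize_draws(draws, prob=0.94):
    """Posterior mean and equal-tailed interval (an HDI would need az.hdi instead)."""
    low, high = np.quantile(draws, [(1 - prob) / 2, (1 + prob) / 2])
    return draws.mean(), low, high


rng = np.random.default_rng(1)
spend = np.full(26, 10_000.0)  # hypothetical: €10K/week for 26 weeks
# Hypothetical posterior: each draw attributes ~€25K/week of revenue, with uncertainty
contributions = rng.normal(25_000, 5_000, size=(4000, 1)) * np.ones((1, 26))
roas = roas_per_draw(contributions, spend)
mean, low, high = summarize_draws(roas)
print(f"ROAS ≈ {mean:.2f} (94% interval {low:.2f} to {high:.2f})")
```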


In [None]:
# Compute ROAS for hierarchical model
hier_roas = compute_roas_from_trace(traces["hierarchical"], df, channel_cols)

print(format_roas_report(hier_roas))

In [None]:
# Visualize ROAS by channel and country
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for idx, channel in enumerate(channel_cols):
    ax = axes[idx]

    countries = hier_roas.mean.index.tolist()
    means = hier_roas.mean[channel].values
    lows = hier_roas.hdi_low[channel].values
    highs = hier_roas.hdi_high[channel].values

    # Color by data availability
    colors = ["#e74c3c" if c in ["PL", "SE"] else "#3498db" for c in countries]

    y_pos = np.arange(len(countries))
    ax.barh(
        y_pos,
        means,
        xerr=[means - lows, highs - means],
        color=colors,
        alpha=0.7,
        capsize=3,
    )

    ax.set_yticks(y_pos)
    ax.set_yticklabels(countries)
    ax.set_xlabel("ROAS")
    ax.set_title(channel.replace("_spend", "").title())
    ax.axvline(means.mean(), color="green", linestyle="--", alpha=0.7)

plt.suptitle("ROAS by Channel × Country (Red = Sparse Markets)", y=1.02, fontsize=14)
plt.tight_layout()
plt.show()

print("💡 Notice: Poland & Sweden (red) should have tighter intervals than 26 weeks")
print("   of their own data would support. That's partial pooling in action.")

<a id="7-shrinkage"></a>
## 7. The Magic: Shrinkage ⭐

This is the key insight of hierarchical models.

**Shrinkage** measures how much each country's estimate was "pulled" toward
the global mean:
- High shrinkage (→1): Country borrowed heavily from the group
- Low shrinkage (→0): Country estimate driven by its own data
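Why does less data mean more shrinkage? Consider the simplest normal-normal hierarchical model. This is a textbook simplification, not necessarily how `compute_shrinkage` defines the metric. A country with n weeks of data, noise SD sigma, and between-country SD tau has its estimate pulled toward the group mean with weight B = s² / (s² + tau²), where s² = sigma² / n is the squared standard error of its own estimate. The noise and spread values below are illustrative assumptions.

```python
import numpy as np


def shrinkage_factor(n_weeks, noise_sd, between_sd):
    """Weight on the group mean in a normal-normal model: B = s^2 / (s^2 + tau^2)."""
    # Sampling variance of the country's own estimate shrinks as 1 / n
    se2 = noise_sd**2 / np.asarray(n_weeks, dtype=float)
    return se2 / (se2 + between_sd**2)


def partially_pooled(country_estimate, group_mean, B):
    """Posterior mean: a B-weighted blend of the group mean and the country's own estimate."""
    return B * group_mean + (1 - B) * country_estimate


weeks = np.array([26, 52, 78, 104])  # PL/SE, ES/IT, FR/NL, DE/UK
B = shrinkage_factor(weeks, noise_sd=1.0, between_sd=0.1)  # illustrative values
for n, b in zip(weeks, B):
    pulled = partially_pooled(1.0, 0.5, b)
    print(f"{n:3d} weeks -> shrinkage {b:.2f}; raw 1.0 with group mean 0.5 becomes {pulled:.2f}")
```

Shrinkage also grows when tau is small: if the model learns that markets are genuinely similar, even data-rich markets borrow more.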


In [None]:
# Compute shrinkage
shrinkage = compute_shrinkage(traces["hierarchical"])

# Heatmap
fig, ax = plt.subplots(figsize=(12, 6))

# Sort by data availability
data_counts = df.groupby("country").size()
country_order = data_counts.sort_values().index.tolist()
shrinkage_sorted = shrinkage.loc[country_order]

sns.heatmap(
    shrinkage_sorted,
    annot=True,
    fmt=".2f",
    cmap="RdYlGn",
    center=0.5,
    vmin=0,
    vmax=1,
    ax=ax,
    cbar_kws={"label": "Shrinkage (0=own data, 1=group mean)"},
    xticklabels=[c.replace("_spend", "") for c in shrinkage_sorted.columns],
)

# Add data count annotations
for i, country in enumerate(country_order):
    ax.text(
        -0.7,
        i + 0.5,
        f"n={data_counts[country]}",
        ha="right",
        va="center",
        fontsize=10,
        style="italic",
    )

ax.set_title("Shrinkage: How Much Did Each Market Borrow From the Group?", fontsize=14)
ax.set_ylabel("Country (sorted by data availability)")
plt.tight_layout()
plt.show()

print("\n📊 READING THE SHRINKAGE PLOT:")
print("   🟢 Green (high): Market borrowed heavily from group")
print("   🔴 Red (low): Market estimate driven by own data")
print("")
print("   → Poland & Sweden (top) should show more green = more borrowing")
print("   → Germany & UK (bottom) should show more red = own estimates")

In [None]:
# Quantify the shrinkage effect
print("\n📈 SHRINKAGE BY MARKET\n")

shrinkage_summary = pd.DataFrame(
    {
        "Country": country_order,
        "Weeks of Data": [data_counts[c] for c in country_order],
        "Mean Shrinkage": [shrinkage.loc[c].mean() for c in country_order],
    }
)
print(shrinkage_summary.to_string(index=False))

print("\n💡 Key Pattern: Less data → More shrinkage → More borrowing from group")

<a id="8-comparison"></a>
## 8. Model Comparison

Does hierarchical structure actually help?


In [None]:
# Compare ROAS stability across models
stability = roas_stability_comparison(traces, channel_cols)

print("📊 ROAS STABILITY COMPARISON")
print("(Cross-country standard deviation of ROAS; lower = less variation across markets)\n")
print(stability.round(3))

In [None]:
# Visualize stability comparison
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(channel_cols))
width = 0.25

for i, model in enumerate(["pooled", "unpooled", "hierarchical"]):
    if model in stability.columns:
        offset = (i - 1) * width
        bars = ax.bar(
            x + offset, stability[model], width, label=model.title(), alpha=0.8
        )

ax.set_xticks(x)
ax.set_xticklabels(
    [c.replace("_spend", "") for c in channel_cols], rotation=45, ha="right"
)
ax.set_ylabel("Cross-Country Std Dev")
ax.set_title("ROAS Stability: Hierarchical Provides the Best Balance")
ax.legend()

plt.tight_layout()
plt.show()

print("\n💡 Interpretation:")
print("   • Pooled: Minimal variation (too rigid: ignores real differences)")
print("   • Unpooled: High variation (too noisy: sparse markets are unstable)")
print("   • Hierarchical: Balanced (stable where needed, flexible where supported)")

In [None]:
# LOO-CV comparison
print("\nüìä MODEL COMPARISON (LOO-CV)")
print("=" * 60)

try:
    comparison = az.compare(traces, ic="loo", scale="log")
    print(comparison)

    print(f"\nüèÜ Best model: {comparison.index[0]}")
except Exception as e:
    print(f"Could not compute LOO-CV: {e}")

<a id="9-validation"></a>
## 9. Ground Truth Validation

Since we generated the data, we can check how well we recovered the true parameters!


In [None]:
# Compare recovered parameters to ground truth
print("📊 GROUND TRUTH COMPARISON")
print("=" * 60)

print("\n[Global Channel Effects: β_μ]\n")

comparison_table = pd.DataFrame(
    {
        "Channel": [c.replace("_spend", "") for c in channel_cols],
        "True": true_params.beta_mu,
        "Recovered": traces["hierarchical"]
        .posterior["beta_mu"]
        .mean(dim=["chain", "draw"])
        .values,
    }
)
comparison_table["Error"] = comparison_table["Recovered"] - comparison_table["True"]
comparison_table["Abs Error"] = comparison_table["Error"].abs()

print(comparison_table.to_string(index=False))
print(f"\nMean Absolute Error: {comparison_table['Abs Error'].mean():.3f}")

In [None]:
# Visual comparison
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(channel_cols))
width = 0.35

ax.bar(
    x - width / 2, true_params.beta_mu, width, label="True", color="green", alpha=0.7
)

recovered_mean = (
    traces["hierarchical"].posterior["beta_mu"].mean(dim=["chain", "draw"]).values
)
recovered_std = (
    traces["hierarchical"].posterior["beta_mu"].std(dim=["chain", "draw"]).values
)
ax.bar(
    x + width / 2,
    recovered_mean,
    width,
    yerr=recovered_std,
    label="Recovered (±1σ)",
    color="blue",
    alpha=0.7,
    capsize=3,
)

ax.set_xticks(x)
ax.set_xticklabels(
    [c.replace("_spend", "") for c in channel_cols], rotation=45, ha="right"
)
ax.set_ylabel("β_μ (Global Channel Effect)")
ax.set_title("Ground Truth vs Hierarchical Model Recovery")
ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Check HDI coverage
print("\n📊 94% HDI COVERAGE CHECK\n")

gt_comparison = compare_to_ground_truth(traces["hierarchical"], true_params, "beta_mu")
print(
    gt_comparison[
        ["channel", "true_value", "recovered_mean", "hdi_low", "hdi_high", "covered"]
    ].to_string(index=False)
)

coverage_rate = gt_comparison["covered"].mean()
print(f"\nCoverage: {coverage_rate:.0%} (target: 94%)")

if coverage_rate >= 0.9:
    print("✅ Coverage is consistent with a well-calibrated model")
else:
    print("⚠️  Coverage is low: the model may be overconfident")
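How much should one missed interval worry us? With only six channels the check above is coarse. Coverage can only be 0/6, 1/6, ..., 6/6, and since 5/6 ≈ 0.83, the 0.9 cutoff effectively demands 6/6. The sketch below treats each 94% interval as an independent coin flip. That is a rough approximation, because all six come from one posterior fit. Under it, the calculation gives the chance that a perfectly calibrated model still gets flagged.

```python
from math import comb


def prob_at_least_k_covered(n_params, k, nominal=0.94):
    """Binomial tail: P(at least k of n intervals contain the truth), each covering w.p. nominal."""
    return sum(
        comb(n_params, j) * nominal**j * (1 - nominal) ** (n_params - j)
        for j in range(k, n_params + 1)
    )


n_channels = 6
k_needed = 6  # 5/6 = 0.83 < 0.9, so the check above needs every interval to cover
p_pass = prob_at_least_k_covered(n_channels, k_needed)
print(f"P(pass | calibrated) = {p_pass:.2f}, P(false alarm) = {1 - p_pass:.2f}")
print(f"P(at least 5 of 6 covered) = {prob_at_least_k_covered(n_channels, 5):.2f}")
```

Under this approximation, a calibrated model fails the check roughly 31% of the time. A single miss out of six is therefore weak evidence of overconfidence. Repeating the check across many simulated datasets (simulation-based calibration) gives a much sharper answer.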

<a id="10-failure-modes"></a>
## 10. When Hierarchical Fails

Partial pooling isn't magic. Let's understand when it breaks.


In [None]:
print("🚨 WHEN HIERARCHICAL MODELS FAIL")
print("=" * 60)

print("""
Partial pooling assumes markets are "exchangeable": similar enough that
borrowing information makes sense. This fails when:

1. MARKETS ARE GENUINELY VERY DIFFERENT
   - Different music cultures (K-pop market vs Classical market)
   - Different platform dominance (TikTok huge in one country, banned in another)
   - Different regulations (radio quotas, streaming taxes)
   
   Symptom: Large between-country standard deviation estimate (β_sigma)
   Solution: Reconsider grouping, add covariates, or use separate models

2. YOU HAVE ALMOST NO DATA ANYWHERE
   - Can't borrow from an empty pool
   - Need at least some markets with decent data
   
   Symptom: Very wide posteriors even with pooling
   Solution: Use stronger informative priors, get more data

3. THE HIERARCHY IS WRONG
   - Pooling by country when you should pool by channel type
   - Pooling all channels when some behave very differently
   
   Symptom: Poor predictive performance, unexpected shrinkage
   Solution: Restructure the hierarchy, try nested groupings

4. NON-EXCHANGEABLE TEMPORAL PATTERNS
   - Markets at different maturity stages
   - One market had a major event (concert tour, scandal)
   
   Symptom: Systematic residual patterns by country
   Solution: Add time-varying effects, market-specific trends
""")

In [None]:
# Check for potential issues in our model
print("\n📊 CHECKING FOR FAILURE MODE INDICATORS\n")

# 1. Between-country spread (beta_sigma is a standard deviation, not a variance)
beta_sigma = (
    traces["hierarchical"].posterior["beta_sigma"].mean(dim=["chain", "draw"]).values
)
print("Between-Country Std Dev (β_sigma):")
for ch, sigma in zip(channel_cols, beta_sigma):
    # 0.3 is an ad hoc cutoff; a sensible scale depends on how β is defined
    flag = " ⚠️ HIGH" if sigma > 0.3 else ""
    print(f"   {ch.replace('_spend', ''):15s}: {sigma:.3f}{flag}")

# 2. Countries with unusual shrinkage
print("\nCountries with Unusual Shrinkage:")
mean_shrinkage = shrinkage.mean(axis=1)
unusual = mean_shrinkage[(mean_shrinkage < 0.2) | (mean_shrinkage > 0.8)]
for country, val in unusual.items():
    flag = "very low (own data dominates)" if val < 0.2 else "very high (heavily pooled)"
    print(f"   {country}: {val:.2f} ({flag})")
if unusual.empty:
    print("   None (all markets between 0.2 and 0.8)")

<a id="11-budget"></a>
## 11. Budget Allocation

The ultimate business deliverable: where should VOLTA invest its €500K quarterly budget?
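Before looking at the recommendation, here is a sketch of one standard way to split a budget when returns diminish. The budget is handed out in small increments, each time to the channel whose next increment adds the most revenue. This assumes a known response curve per channel (here a Hill curve times an effect size, ignoring adstock carryover). It is not necessarily what `compute_optimal_allocation` does; that function may, for example, allocate in proportion to ROAS. Greedy allocation is optimal when every response curve is concave (Hill with S ≤ 1). With S-shaped curves (S > 1) it is only a heuristic and can under-invest in channels that need a minimum spend to pay off. All numbers below are hypothetical.

```python
import numpy as np


def hill(x, K, S):
    """Hill saturation curve (same form as in the transforms section)."""
    return x**S / (K**S + x**S)


def greedy_allocation(budget, effect, K, S, step):
    """Allocate `budget` in `step` increments to the channel with the largest marginal gain."""
    effect, K, S = (np.asarray(a, dtype=float) for a in (effect, K, S))
    alloc = np.zeros(effect.size)
    for _ in range(int(round(budget / step))):
        # Revenue added by giving each channel one more step of spend
        gain = effect * (hill(alloc + step, K, S) - hill(alloc, K, S))
        alloc[np.argmax(gain)] += step
    return alloc


# Hypothetical two-channel example in normalized spend units (concave curves, S = 1)
alloc = greedy_allocation(budget=1.0, effect=[2.0, 1.0], K=[0.5, 0.5], S=[1.0, 1.0], step=0.01)
print(f"Allocation: {alloc.round(2)} (sum = {alloc.sum():.2f})")
```

At the greedy solution, the funded channels' marginal returns are approximately equal, which is the usual condition for an optimal split. Ranking channels by average ROAS alone ignores saturation. It can therefore keep funding a channel past the point where its next euro returns less than another channel's.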


In [None]:
# Compute optimal allocation
QUARTERLY_BUDGET = 500_000  # €500K

allocation = compute_optimal_allocation(hier_roas, QUARTERLY_BUDGET)

print(f"📊 RECOMMENDED BUDGET ALLOCATION (€{QUARTERLY_BUDGET:,})")
print("=" * 60)
print("\nBased on hierarchical model ROAS estimates:\n")

# Format nicely
allocation_display = allocation.copy()
allocation_display.columns = [
    c.replace("_spend", "") for c in allocation_display.columns
]
allocation_display["TOTAL"] = allocation_display.sum(axis=1)

# Add grand total row
totals = allocation_display.sum()
totals.name = "TOTAL"
allocation_display = pd.concat([allocation_display, totals.to_frame().T])

print(allocation_display.round(0).astype(int).to_string())

In [None]:
# Visualize allocation
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# By channel
ax1 = axes[0]
channel_totals = allocation.sum()
channel_totals.index = [c.replace("_spend", "") for c in channel_totals.index]
channel_totals.sort_values().plot(kind="barh", ax=ax1, color="steelblue", alpha=0.8)
ax1.set_xlabel("Budget (€)")
ax1.set_title("Total Allocation by Channel")

# By country
ax2 = axes[1]
country_totals = allocation.sum(axis=1)
country_totals.sort_values().plot(kind="barh", ax=ax2, color="coral", alpha=0.8)
ax2.set_xlabel("Budget (€)")
ax2.set_title("Total Allocation by Market")

plt.suptitle(
    f"Recommended Budget Allocation: €{QUARTERLY_BUDGET:,}", y=1.02, fontsize=14
)
plt.tight_layout()
plt.show()

In [None]:
print("💾 Saving traces...")

from pathlib import Path

results_dir = Path("../results")
results_dir.mkdir(parents=True, exist_ok=True)  # also create missing parent folders

for name, trace in traces.items():
    path = results_dir / f"{name}_trace.nc"
    trace.to_netcdf(path)
    print(f"   ✅ Saved {path}")

print("\n🎉 Notebook complete!")