# Nepenthe Processing Pipeline

This notebook processes organization data through the following steps:

1. **Add Subregion**: Extract country from headquarters location and map to UN M49 subregions
2. **Create Subgroups**: Split organizations into confidence categories (Probable, Possible, Non-zero, Not Detected) based on their ML staff estimates
3. **Summary Table**: Generate comparison table across all subgroups

**Input:** `2026-01-28_final_results_main_orgs.csv` (read from `..` if it exists there, otherwise from the current directory)  
**Output:** Comparison table with organization counts, employee totals, ML estimates, and regional breakdowns

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Patch
import matplotlib.ticker as mticker
import warnings
import re
from matplotlib.lines import Line2D
from matplotlib.patches import FancyBboxPatch
from matplotlib.colors import LinearSegmentedColormap
warnings.filterwarnings('ignore')  # silence every warning to keep notebook output clean

# =============================================================================
# DESIGN SYSTEM — Academic/Scientific Style
# =============================================================================
# Clean, publication-ready figures in the spirit of Nature/Science journals.

# Typography: preferred typeface plus a point size for each text role.
FONT_FAMILY = 'Helvetica Neue'
FONT_SIZES = dict(
    title=14,
    subtitle=11,
    axis_label=11,
    tick_label=9,
    legend=9,
    annotation=8,
    org_label=7,
)

# Global matplotlib defaults, part 1: fonts and text sizes.
# The sans-serif list falls back from FONT_FAMILY to Helvetica, Arial, DejaVu Sans.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': [FONT_FAMILY, 'Helvetica', 'Arial', 'DejaVu Sans'],
    'font.size': FONT_SIZES['tick_label'],
    'axes.titlesize': FONT_SIZES['title'],
    'axes.labelsize': FONT_SIZES['axis_label'],
    'xtick.labelsize': FONT_SIZES['tick_label'],
    'ytick.labelsize': FONT_SIZES['tick_label'],
    'legend.fontsize': FONT_SIZES['legend'],
    'figure.titlesize': FONT_SIZES['title'],
    'axes.titleweight': 'medium',
    'axes.labelweight': 'regular',
})

# Part 2: thin spines (top/right hidden), a faint grid, white backgrounds
# and 300 dpi saved figures.
plt.rcParams.update({
    'axes.linewidth': 0.8,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.alpha': 0.25,
    'grid.linewidth': 0.5,
    'grid.linestyle': '-',
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'savefig.facecolor': 'white',
    'savefig.dpi': 300,
})

# =============================================================================
# COLOR PALETTE — Refined Academic
# =============================================================================
# Muted colours with good contrast, chosen with colour-blind readers in mind.

# Six-colour categorical palette.
PALETTE = dict(
    blue='#3C5488',    # deep slate blue
    red='#DC3220',     # vermillion (colour-blind-safe red)
    green='#009988',   # teal / cyan-green
    gold='#E68613',    # amber / ochre
    purple='#7B4B94',  # muted violet
    gray='#868686',    # neutral gray
)

# The same six colours in palette order, for cycling through scatter series.
PALETTE_LIST = list(PALETTE.values())

# Colours for specific roles; several reuse palette entries.
COLORS = {
    # Confidence-interval shading
    'ci_pure_probit': '#2C6E49',                 # forest green
    'ci_adjusted_synthetic': PALETTE['purple'],

    # Confidence categories (landscape plot)
    'probable': '#2C6E49',                       # forest green: CI excludes zero
    'possible': PALETTE['gold'],                 # central estimate positive
    'nonzero': '#7EB5D6',                        # light blue: upper bound positive
    'not_detected': PALETTE['gray'],             # all quantiles zero

    # Geographic / thematic
    'primary': PALETTE['blue'],
    'secondary': PALETTE['green'],
    'muted': PALETTE['gray'],
    'background': '#F5F5F5',
    'gridline': '#E0E0E0',
}

# Landscape plot: display label -> colour, in legend order.
LANDSCAPE_PALETTE = {
    label: COLORS[role]
    for label, role in (
        ("Probable", 'probable'),
        ("Possible", 'possible'),
        ("Non-zero", 'nonzero'),
        ("Not Detected", 'not_detected'),
    )
}

# Landscape plot: display label -> marker (diamond, square, triangle up, circle).
LANDSCAPE_MARKERS = dict(zip(LANDSCAPE_PALETTE, ("D", "s", "^", "o")))

# Step 4 plot: per-estimator colour, marker and marker size.
# Keyword filters are muted and small; LLM estimators are coloured and larger.
ESTIMATOR_STYLES = {
    'filter_broad_yes': dict(color=PALETTE['gray'], marker='o', size=24),
    'filter_strict_no': dict(color=PALETTE['gray'], marker='v', size=24),
    'filter_broad_yes_strict_no': dict(color='#5A5A5A', marker='s', size=24),
    'claude_total_accepted': dict(color=PALETTE['blue'], marker='D', size=36),
    'gpt5_total_accepted': dict(color=PALETTE['green'], marker='^', size=36),
    'gemini_total_accepted': dict(color=PALETTE['gold'], marker='P', size=40),
}

# Step 4 plot: human-readable estimator names.
ESTIMATOR_LABELS = dict(
    filter_broad_yes='Keyword: Broad Yes',
    filter_strict_no='Keyword: Strict No',
    filter_broad_yes_strict_no='Keyword: Broad+Strict',
    claude_total_accepted='Claude (sonnet-4)',
    gpt5_total_accepted='GPT-5-mini',
    gemini_total_accepted='Gemini 2.5 Flash',
)

# =============================================================================
# CONFIGURATION
# =============================================================================

SAVE_OUTPUTS = True  # when True, subgroup and summary CSVs are written to DATA_DIR/output

# Use the parent directory if the input CSV lives there, otherwise the cwd.
if Path('../2026-01-28_final_results_main_orgs.csv').exists():
    DATA_DIR = Path('..')
else:
    DATA_DIR = Path('.')

def assign_confidence_category(q10, q50, q90):
    """
    Classify one organization by how confidently its ML estimate exceeds zero.

    Checks run from strongest to weakest evidence; the first match wins:

    - "Probable":     q10 > 0  (the 80% interval excludes zero)
    - "Possible":     q50 > 0  (central estimate positive, q10 is not)
    - "Non-zero":     q90 > 0  (only the upper bound is positive)
    - "Not Detected": none of the above, or any quantile is missing (None/NaN)
    """
    quantiles = (q10, q50, q90)
    if any(pd.isna(q) for q in quantiles):
        return "Not Detected"
    for label, value in zip(("Probable", "Possible", "Non-zero"), quantiles):
        if value > 0:
            return label
    return "Not Detected"


# =============================================================================
# HELPER FUNCTIONS — Figure Styling
# =============================================================================

def format_log_axis(ax, axis='y', limits=(1, 10000)):
    """
    Switch one axis of *ax* to a log scale with plain decade tick labels.

    Args:
        ax: Matplotlib Axes, modified in place.
        axis: 'y' formats the y-axis; any other value formats the x-axis.
        limits: (low, high) axis limits. Ticks are placed only at the decades
            1, 10, 100, 1000 and 10000 that lie within these limits.

    Minor ticks are removed, and labels use thousands separators (e.g. "1,000").
    """
    side = 'y' if axis == 'y' else 'x'
    getattr(ax, f'set_{side}scale')('log')
    getattr(ax, f'set_{side}lim')(limits)
    decades = [d for d in (1, 10, 100, 1000, 10000) if limits[0] <= d <= limits[1]]
    getattr(ax, f'set_{side}ticks')(decades)
    # '{:,}' leaves 1, 10 and 100 unchanged and renders 1000 as "1,000".
    getattr(ax, f'set_{side}ticklabels')([f'{d:,}' for d in decades])
    getattr(ax, f'{side}axis').set_minor_locator(mticker.NullLocator())

---
## Step 1: Add Subregion Column

Extract country information from headquarters location strings and map to UN M49 subregions.

In [None]:
# Country name variations mapping.
# Keys are alternative spellings/abbreviations that can appear as the last
# comma-separated part of a headquarters location; values are the canonical
# names used as keys of COUNTRY_TO_SUBREGION.
COUNTRY_VARIATIONS = {
    # United States
    'USA': 'United States',
    'US': 'United States',
    'U.S.': 'United States',
    'U.S.A.': 'United States',
    'United States of America': 'United States',

    # United Kingdom (including its constituent countries)
    'UK': 'United Kingdom',
    'United Kingdom of Great Britain and Northern Ireland': 'United Kingdom',
    'England': 'United Kingdom',
    'Scotland': 'United Kingdom',
    'Wales': 'United Kingdom',
    'Northern Ireland': 'United Kingdom',

    # Prefixed / official forms observed in the source data
    'The Netherlands': 'Netherlands',
    'NA - South Africa': 'South Africa',
    'NA - Vietnam': 'Vietnam',
    'NA - Uruguay': 'Uruguay',
    'Russian Federation': 'Russia',

    # Other common alternative names
    'Viet Nam': 'Vietnam',
    'Republic of Korea': 'South Korea',
    'Türkiye': 'Turkey',
    'UAE': 'United Arab Emirates',
    'Macau': 'Macao',
    'Ivory Coast': "Côte d'Ivoire",
    'Cape Verde': 'Cabo Verde',
    'Swaziland': 'Eswatini',
}

In [None]:
# UN M49 subregion mapping: canonical country/territory name -> subregion.
# Names that are not keys here are reported as 'Unknown' by
# map_country_to_subregion. Dependent territories that commonly appear as
# company headquarters (e.g. Puerto Rico, Cayman Islands, Isle of Man) are
# included with their M49 subregion.
COUNTRY_TO_SUBREGION = {
    # Northern America
    'United States': 'Northern America',
    'Canada': 'Northern America',
    'Bermuda': 'Northern America',
    'Greenland': 'Northern America',
    'Saint Pierre and Miquelon': 'Northern America',

    # Central America
    'Belize': 'Central America',
    'Costa Rica': 'Central America',
    'El Salvador': 'Central America',
    'Guatemala': 'Central America',
    'Honduras': 'Central America',
    'Mexico': 'Central America',
    'Nicaragua': 'Central America',
    'Panama': 'Central America',

    # South America
    'Argentina': 'South America',
    'Bolivia': 'South America',
    'Brazil': 'South America',
    'Chile': 'South America',
    'Colombia': 'South America',
    'Ecuador': 'South America',
    'French Guiana': 'South America',
    'Guyana': 'South America',
    'Paraguay': 'South America',
    'Peru': 'South America',
    'Suriname': 'South America',
    'Uruguay': 'South America',
    'Venezuela': 'South America',

    # Western Europe
    'Austria': 'Western Europe',
    'Belgium': 'Western Europe',
    'France': 'Western Europe',
    'Germany': 'Western Europe',
    'Liechtenstein': 'Western Europe',
    'Luxembourg': 'Western Europe',
    'Monaco': 'Western Europe',
    'Netherlands': 'Western Europe',
    'Switzerland': 'Western Europe',

    # Northern Europe
    'Denmark': 'Northern Europe',
    'Estonia': 'Northern Europe',
    'Faroe Islands': 'Northern Europe',
    'Finland': 'Northern Europe',
    'Guernsey': 'Northern Europe',
    'Iceland': 'Northern Europe',
    'Ireland': 'Northern Europe',
    'Isle of Man': 'Northern Europe',
    'Jersey': 'Northern Europe',
    'Latvia': 'Northern Europe',
    'Lithuania': 'Northern Europe',
    'Norway': 'Northern Europe',
    'Sweden': 'Northern Europe',
    'United Kingdom': 'Northern Europe',

    # Eastern Europe
    'Belarus': 'Eastern Europe',
    'Bulgaria': 'Eastern Europe',
    'Czech Republic': 'Eastern Europe',
    'Czechia': 'Eastern Europe',
    'Hungary': 'Eastern Europe',
    'Poland': 'Eastern Europe',
    'Moldova': 'Eastern Europe',
    'Romania': 'Eastern Europe',
    'Russia': 'Eastern Europe',
    'Russian Federation': 'Eastern Europe',
    'Slovakia': 'Eastern Europe',
    'Ukraine': 'Eastern Europe',

    # Southern Europe
    'Albania': 'Southern Europe',
    'Andorra': 'Southern Europe',
    'Bosnia and Herzegovina': 'Southern Europe',
    'Croatia': 'Southern Europe',
    'Gibraltar': 'Southern Europe',
    'Greece': 'Southern Europe',
    'Italy': 'Southern Europe',
    'Malta': 'Southern Europe',
    'Montenegro': 'Southern Europe',
    'North Macedonia': 'Southern Europe',
    'Portugal': 'Southern Europe',
    'San Marino': 'Southern Europe',
    'Serbia': 'Southern Europe',
    'Slovenia': 'Southern Europe',
    'Spain': 'Southern Europe',
    'Vatican City': 'Southern Europe',

    # Eastern Asia
    'China': 'Eastern Asia',
    'Hong Kong': 'Eastern Asia',
    'Japan': 'Eastern Asia',
    'Macao': 'Eastern Asia',
    'Mongolia': 'Eastern Asia',
    'North Korea': 'Eastern Asia',
    'South Korea': 'Eastern Asia',
    'Taiwan': 'Eastern Asia',

    # South-Eastern Asia
    'Brunei': 'South-Eastern Asia',
    'Cambodia': 'South-Eastern Asia',
    'Indonesia': 'South-Eastern Asia',
    'Laos': 'South-Eastern Asia',
    'Malaysia': 'South-Eastern Asia',
    'Myanmar': 'South-Eastern Asia',
    'Philippines': 'South-Eastern Asia',
    'Singapore': 'South-Eastern Asia',
    'Thailand': 'South-Eastern Asia',
    'Timor-Leste': 'South-Eastern Asia',
    'Vietnam': 'South-Eastern Asia',

    # Southern Asia
    'Afghanistan': 'Southern Asia',
    'Bangladesh': 'Southern Asia',
    'Bhutan': 'Southern Asia',
    'India': 'Southern Asia',
    'Iran': 'Southern Asia',
    'Maldives': 'Southern Asia',
    'Nepal': 'Southern Asia',
    'Pakistan': 'Southern Asia',
    'Sri Lanka': 'Southern Asia',

    # Western Asia
    'Armenia': 'Western Asia',
    'Azerbaijan': 'Western Asia',
    'Bahrain': 'Western Asia',
    'Cyprus': 'Western Asia',
    'Georgia': 'Western Asia',
    'Iraq': 'Western Asia',
    'Israel': 'Western Asia',
    'Jordan': 'Western Asia',
    'Kuwait': 'Western Asia',
    'Lebanon': 'Western Asia',
    'Oman': 'Western Asia',
    'Palestine': 'Western Asia',
    'Qatar': 'Western Asia',
    'Saudi Arabia': 'Western Asia',
    'Syria': 'Western Asia',
    'Turkey': 'Western Asia',
    'United Arab Emirates': 'Western Asia',
    'Yemen': 'Western Asia',

    # Central Asia
    'Kazakhstan': 'Central Asia',
    'Kyrgyzstan': 'Central Asia',
    'Tajikistan': 'Central Asia',
    'Turkmenistan': 'Central Asia',
    'Uzbekistan': 'Central Asia',

    # Northern Africa
    'Algeria': 'Northern Africa',
    'Egypt': 'Northern Africa',
    'Libya': 'Northern Africa',
    'Morocco': 'Northern Africa',
    'Sudan': 'Northern Africa',
    'Tunisia': 'Northern Africa',
    'Western Sahara': 'Northern Africa',

    # Eastern Africa
    'Burundi': 'Eastern Africa',
    'Comoros': 'Eastern Africa',
    'Djibouti': 'Eastern Africa',
    'Eritrea': 'Eastern Africa',
    'Ethiopia': 'Eastern Africa',
    'Kenya': 'Eastern Africa',
    'Madagascar': 'Eastern Africa',
    'Malawi': 'Eastern Africa',
    'Mauritius': 'Eastern Africa',
    'Mayotte': 'Eastern Africa',
    'Mozambique': 'Eastern Africa',
    'Réunion': 'Eastern Africa',
    'Rwanda': 'Eastern Africa',
    'Seychelles': 'Eastern Africa',
    'Somalia': 'Eastern Africa',
    'South Sudan': 'Eastern Africa',
    'Tanzania': 'Eastern Africa',
    'Uganda': 'Eastern Africa',
    'Zambia': 'Eastern Africa',
    'Zimbabwe': 'Eastern Africa',

    # Southern Africa
    'Botswana': 'Southern Africa',
    'Eswatini': 'Southern Africa',
    'Lesotho': 'Southern Africa',
    'Namibia': 'Southern Africa',
    'South Africa': 'Southern Africa',

    # Western Africa
    'Benin': 'Western Africa',
    'Burkina Faso': 'Western Africa',
    'Cabo Verde': 'Western Africa',
    "Côte d'Ivoire": 'Western Africa',
    'Gambia': 'Western Africa',
    'Ghana': 'Western Africa',
    'Guinea': 'Western Africa',
    'Guinea-Bissau': 'Western Africa',
    'Liberia': 'Western Africa',
    'Mali': 'Western Africa',
    'Mauritania': 'Western Africa',
    'Niger': 'Western Africa',
    'Nigeria': 'Western Africa',
    'Senegal': 'Western Africa',
    'Sierra Leone': 'Western Africa',
    'Togo': 'Western Africa',

    # Middle Africa
    'Angola': 'Middle Africa',
    'Cameroon': 'Middle Africa',
    'Central African Republic': 'Middle Africa',
    'Chad': 'Middle Africa',
    'Congo': 'Middle Africa',
    'Democratic Republic of the Congo': 'Middle Africa',
    'Equatorial Guinea': 'Middle Africa',
    'Gabon': 'Middle Africa',
    'São Tomé and Príncipe': 'Middle Africa',

    # Australia and New Zealand
    'Australia': 'Australia and New Zealand',
    'New Zealand': 'Australia and New Zealand',

    # Caribbean
    'Antigua and Barbuda': 'Caribbean',
    'Aruba': 'Caribbean',
    'Bahamas': 'Caribbean',
    'Barbados': 'Caribbean',
    'British Virgin Islands': 'Caribbean',
    'Cayman Islands': 'Caribbean',
    'Cuba': 'Caribbean',
    'Curaçao': 'Caribbean',
    'Dominica': 'Caribbean',
    'Dominican Republic': 'Caribbean',
    'Grenada': 'Caribbean',
    'Haiti': 'Caribbean',
    'Jamaica': 'Caribbean',
    'Puerto Rico': 'Caribbean',
    'Saint Kitts and Nevis': 'Caribbean',
    'Saint Lucia': 'Caribbean',
    'Saint Vincent and the Grenadines': 'Caribbean',
    'Trinidad and Tobago': 'Caribbean',

    # Melanesia
    'Fiji': 'Melanesia',
    'New Caledonia': 'Melanesia',
    'Papua New Guinea': 'Melanesia',
    'Solomon Islands': 'Melanesia',
    'Vanuatu': 'Melanesia',

    # Micronesia
    'Guam': 'Micronesia',
    'Kiribati': 'Micronesia',
    'Marshall Islands': 'Micronesia',
    'Micronesia': 'Micronesia',
    'Nauru': 'Micronesia',
    'Northern Mariana Islands': 'Micronesia',
    'Palau': 'Micronesia',

    # Polynesia
    'American Samoa': 'Polynesia',
    'Cook Islands': 'Polynesia',
    'French Polynesia': 'Polynesia',
    'Niue': 'Polynesia',
    'Samoa': 'Polynesia',
    'Tonga': 'Polynesia',
    'Tuvalu': 'Polynesia',
}

In [None]:
def extract_country_from_location(location, variations=None):
    """
    Extract a canonical country name from a headquarters location string.

    The country is taken to be the last non-empty comma-separated part of the
    location (e.g. "Austin, Texas, USA" -> "USA"). That part is then normalised
    through *variations* (e.g. "USA" -> "United States").

    Args:
        location: Location string. None, NaN and "" yield "".
        variations: Optional mapping of alternative names to canonical names.
            Defaults to the module-level COUNTRY_VARIATIONS.

    Returns:
        The canonical country name; the raw last part if it has no entry in
        *variations*; or "" if the location contains no country part.
    """
    if variations is None:
        variations = COUNTRY_VARIATIONS
    if pd.isna(location) or location == "":
        return ""

    # Skip empty segments so a trailing comma ("Berlin, Germany,") or stray
    # separators don't turn the country into "".
    parts = [part.strip() for part in str(location).split(',')]
    parts = [part for part in parts if part]
    if not parts:
        return ""

    country = parts[-1]
    return variations.get(country, country)


def map_country_to_subregion(country):
    """
    Return the UN M49 subregion for a canonical country name.

    Any name that is not a key of COUNTRY_TO_SUBREGION (including "") maps to
    'Unknown'.
    """
    subregion = COUNTRY_TO_SUBREGION.get(country)
    return 'Unknown' if subregion is None else subregion

In [None]:
# Load the organization-level results table from DATA_DIR.
input_file = DATA_DIR.joinpath('2026-01-28_final_results_main_orgs.csv')
df = pd.read_csv(input_file)
print(f"Loaded {df.shape[0]} rows from {input_file.name}")

In [None]:
# The country is parsed from 'headquarters_location', so stop early if it is absent.
if 'headquarters_location' not in df.columns:
    raise ValueError(f"'headquarters_location' column not found. Available: {list(df.columns)}")

# Two-stage lookup: location string -> canonical country -> UN M49 subregion.
df['Country'] = df['headquarters_location'].apply(extract_country_from_location)
df['Subregion'] = df['Country'].apply(map_country_to_subregion)

print("Added Country and Subregion columns")

In [None]:
# Summary statistics for the country -> subregion mapping
total_rows = len(df)
unknown_count = int((df['Subregion'] == 'Unknown').sum())
mapped_count = total_rows - unknown_count
# Guard against an empty table, which would otherwise raise ZeroDivisionError.
success_rate = mapped_count / total_rows * 100 if total_rows else 0.0

print("Subregion Mapping Summary")
print("=" * 40)
print(f"Total locations processed: {total_rows}")
print(f"Successfully mapped: {mapped_count}")
print(f"Unknown mappings: {unknown_count}")
print(f"Success rate: {success_rate:.1f}%")

In [None]:
# Number of organizations per subregion, most frequent first.
print("Subregion Distribution")
print("=" * 40)
subregion_counts = df['Subregion'].value_counts()
for region_name, n_orgs in subregion_counts.items():
    print(f"  {region_name}: {n_orgs}")

In [None]:
# List the country strings that could not be mapped, to help extend the lookup tables.
if unknown_count:
    print("Unknown countries found:")
    unknown_countries = df.loc[df['Subregion'] == 'Unknown', 'Country'].value_counts()
    for country_name, n_orgs in unknown_countries.items():
        print(f"  {country_name}: {n_orgs}")

In [None]:
# Show how the first ten headquarters locations were parsed (cell output).
print("Sample of processed data:")
df.loc[:, ['headquarters_location', 'Country', 'Subregion']].iloc[:10]

---
## Step 2: Create Subgroups

Filter organizations into confidence categories based on ML estimate uncertainty:

| Category | Criteria | Description |
|----------|----------|-------------|
| **All** | - | All 403 organizations |
| **Probable** | q10 > 0 | 80% CI excludes zero - confident ML presence |
| **Possible** | q50 > 0, q10 = 0 | Central estimate positive but CI includes zero |
| **Non-zero** | q90 > 0, q50 = 0 | Upper bound positive but central estimate zero |
| **Not Detected** | q90 = q50 = q10 = 0 | All estimates zero - no ML signal |

Here q10/q50/q90 are the pure probit estimates for organizations whose probit q50 is available; for all other organizations the adjusted synthetic estimates are used.

In [None]:
# Key column names: the ML staff estimate (adjusted synthetic median) and organization size.
ML_STAFF_COL, HEADCOUNT_COL = 'adjusted_synthetic_q50', 'total_headcount'

In [None]:
# Check that the key columns are present; warn rather than fail if any are missing.
required_cols = [ML_STAFF_COL, HEADCOUNT_COL]
missing_cols = [name for name in required_cols if name not in df.columns]

print(f"Warning: Missing columns: {missing_cols}" if missing_cols else "All required columns present")

In [None]:
# Calculate ML estimate using fallback: q50 if available, else adjusted_synthetic_q50
q50_pure = pd.to_numeric(df['q50'], errors='coerce') if 'q50' in df.columns else pd.Series(np.nan, index=df.index)
q50_synthetic = pd.to_numeric(df['adjusted_synthetic_q50'], errors='coerce')

# Use pure probit q50 when not empty, else synthetic
df['ml_estimate'] = q50_pure.where(q50_pure.notna(), q50_synthetic)

# ML share of headcount. Zero, negative or missing headcounts are treated as
# missing, so the share is NaN instead of inf (x / 0) or a negative value.
headcount = pd.to_numeric(df[HEADCOUNT_COL], errors='coerce')
df['ml_share_calc'] = df['ml_estimate'] / headcount.where(headcount > 0)

# Get values for masking
ml_count = df['ml_estimate']
ml_share = df['ml_share_calc']

# Summary stats. Rows without any estimate are counted as "synthetic" here.
n_pure = q50_pure.notna().sum()
n_synthetic = len(df) - n_pure
print(f"ML estimate source: {n_pure} using pure probit q50, {n_synthetic} using adjusted synthetic q50")
print(f"ML staff (estimate) - min: {ml_count.min():.1f}, max: {ml_count.max():.1f}, median: {ml_count.median():.1f}")
print(f"ML share - min: {ml_share.min():.4f}, max: {ml_share.max():.4f}, median: {ml_share.median():.4f}")

In [None]:
# Create subgroup masks

# 1. All companies (no filter)
mask_all = pd.Series(True, index=df.index)

# -----------------------------------------------------------------------------
# Confidence categories based on statistical estimates.
# Pure probit quantiles are used for rows whose probit q50 is present;
# all other rows use the adjusted synthetic quantiles.
# -----------------------------------------------------------------------------
q10_pure = pd.to_numeric(df['q10'], errors='coerce')
q50_pure = pd.to_numeric(df['q50'], errors='coerce')
q90_pure = pd.to_numeric(df['q90'], errors='coerce')
q10_synthetic = pd.to_numeric(df['adjusted_synthetic_q10'], errors='coerce')
q50_synthetic = pd.to_numeric(df['adjusted_synthetic_q50'], errors='coerce')
q90_synthetic = pd.to_numeric(df['adjusted_synthetic_q90'], errors='coerce')

use_pure = q50_pure.notna()
effective_q10 = q10_pure.where(use_pure, q10_synthetic)
effective_q50 = q50_pure.where(use_pure, q50_synthetic)
effective_q90 = q90_pure.where(use_pure, q90_synthetic)

# Apply the same first-match rules as assign_confidence_category(). A row
# with any missing quantile counts as "Not Detected" instead of falling
# through every `== 0` test (NaN == 0 is False). The four masks are
# therefore mutually exclusive and together cover every row.
has_quantiles = effective_q10.notna() & effective_q50.notna() & effective_q90.notna()
q10_positive = has_quantiles & (effective_q10 > 0)
q50_positive = has_quantiles & (effective_q50 > 0)
q90_positive = has_quantiles & (effective_q90 > 0)

# 2. Probable: q10 > 0 (80% CI excludes zero - confident ML presence)
mask_probable = q10_positive

# 3. Possible: q50 > 0 but q10 is not (central estimate positive but uncertain)
mask_possible = q50_positive & ~q10_positive

# 4. Non-zero: only q90 > 0 (upper bound positive only)
mask_nonzero = q90_positive & ~q50_positive & ~q10_positive

# 5. Not Detected: everything else (all quantiles zero, or a quantile missing)
mask_not_detected = ~(mask_probable | mask_possible | mask_nonzero)

In [None]:
# Materialise each subgroup as an independent copy, keyed by short name.
subgroups = {
    key: df[mask].copy()
    for key, mask in (
        ('all', mask_all),
        ('probable', mask_probable),
        ('possible', mask_possible),
        ('nonzero', mask_nonzero),
        ('not_detected', mask_not_detected),
    )
}

# Convenience aliases (the same objects as the dict values).
df_all = subgroups['all']
df_probable = subgroups['probable']
df_possible = subgroups['possible']
df_nonzero = subgroups['nonzero']
df_not_detected = subgroups['not_detected']

In [None]:
# Summary of subgroup sizes
n_all = len(df_all)
n_categorised = len(df_probable) + len(df_possible) + len(df_nonzero) + len(df_not_detected)

print("Subgroup Summary")
print("=" * 60)
print(f"{'Category':<20} {'Count':>8} {'Criteria'}")
print("-" * 60)
print(f"{'All':<20} {n_all:>8} All organizations")
print("-" * 60)
print("Confidence Categories (mutually exclusive):")
print(f"{'Probable':<20} {len(df_probable):>8} q10 > 0 (CI excludes zero)")
print(f"{'Possible':<20} {len(df_possible):>8} q50 > 0, q10 = 0")
print(f"{'Non-zero':<20} {len(df_nonzero):>8} q90 > 0, q50 = 0")
print(f"{'Not Detected':<20} {len(df_not_detected):>8} all zeros")
print("-" * 60)
print(f"{'Categories total':<20} {n_categorised:>8} (should equal All)")

# Flag a mismatch explicitly rather than relying on the reader to compare numbers.
if n_categorised != n_all:
    print(f"WARNING: categories account for {n_categorised} of {n_all} organizations; "
          "check for rows with missing quantiles or overlapping category masks.")

In [None]:
# Show the first five rows of every non-empty subgroup.
preview_cols = ['organization_name', HEADCOUNT_COL, 'ml_estimate', 'ml_share_calc', 'Subregion']

for group_name, group_df in subgroups.items():
    if len(group_df) == 0:
        continue  # nothing to preview
    print(f"\n{group_name.upper()} - Sample (first 5):")
    display(group_df[preview_cols].head())

In [None]:
# Optional: write each subgroup to its own CSV under DATA_DIR/output.
if not SAVE_OUTPUTS:
    print("Output save skipped (set SAVE_OUTPUTS=True to enable)")
else:
    output_dir = DATA_DIR / 'output'
    output_dir.mkdir(exist_ok=True)

    # Subgroup key -> output file name
    output_files = dict(
        all='all_orgs.csv',
        probable='orgs_probable.csv',
        possible='orgs_possible.csv',
        nonzero='orgs_nonzero.csv',
        not_detected='orgs_not_detected.csv',
    )

    for name, filename in output_files.items():
        filepath = output_dir / filename
        group_df = subgroups[name]
        group_df.to_csv(filepath, index=False)
        print(f"Saved {name}: {filepath} (N={len(group_df)})")

---
## Step 3: Summary Comparison Table

Generate a comprehensive comparison table across all subgroups with:
- Organization counts and percentages
- Employee totals and medians
- ML engineer estimates (q10, q50, q90)
- ML talent percentages
- Size breakdowns
- Regional distributions

In [None]:
# Summary table configuration
PCT_DECIMALS = 1
INDENT = "\u00A0" * 4  # Non-breaking spaces for indentation

# Regions in display order (rows of the regional breakdown)
REGION_ORDER = [
    "Northern America",
    "Western Europe",
    "Southern Asia",
    "Eastern Asia",
    "Northern Europe",
    "Eastern Europe",
    "Western Asia",
    "South America",
    "Southern Europe",
    "Unknown",
    "Australia and New Zealand",
    "South-Eastern Asia",
    "Central America",
    "Southern Africa",
    "Northern Africa",
    "Western Africa",
    "Caribbean",
]

# The regional breakdown only reports regions listed in REGION_ORDER, so any
# subregion present in the data but missing above would be silently dropped.
# Append such regions (alphabetically) at the end.
_extra_regions = sorted(set(df['Subregion'].dropna()) - set(REGION_ORDER))
if _extra_regions:
    print(f"Adding subregions missing from REGION_ORDER: {_extra_regions}")
    REGION_ORDER.extend(_extra_regions)

# Size sections: (label, predicate on a numeric headcount Series).
# Upper bounds are exclusive, so non-integer headcounts (e.g. 999.5) still
# fall into exactly one band; integer headcounts are banded as before.
SIZE_SECTIONS = [
    ("Small (< 100 employees)", lambda hc: hc < 100),
    ("Medium (100-999 employees)", lambda hc: (hc >= 100) & (hc < 1000)),
    ("Large (1,000-9,999 employees)", lambda hc: (hc >= 1000) & (hc < 10000)),
    ("Giant (≥10,000 employees)", lambda hc: hc >= 10000),
]

In [None]:
# Helper functions for formatting

def format_int_iso(n):
    """
    Format *n* as an integer with a space between each group of three digits.

    Example: 1234567 -> "1 234 567", -1234 -> "-1 234". Floats are truncated
    toward zero. None, NaN/NA and non-finite Python floats give "".
    """
    if n is None:
        return ""
    if isinstance(n, float) and not np.isfinite(n):
        return ""
    if pd.isna(n):
        return ""
    value = int(n)
    grouped = f"{abs(value):,}".replace(",", " ")
    return ("-" if value < 0 else "") + grouped

def pct_string(numer, denom, decimals=PCT_DECIMALS):
    """
    Format numer / denom as a percentage string, e.g. "12.5%".

    Args:
        numer: Numerator.
        denom: Denominator.
        decimals: Digits after the decimal point.

    Returns:
        The formatted percentage, or "n/a" if either value is missing
        (None/NaN/pd.NA), the denominator is zero or non-finite, or the
        numerator is non-finite.
    """
    # Test for missing values first: `pd.NA == 0` evaluates to pd.NA, whose
    # truth value raises TypeError inside an `or` chain.
    if numer is None or denom is None or pd.isna(numer) or pd.isna(denom):
        return "n/a"
    numer, denom = float(numer), float(denom)
    if denom == 0 or not np.isfinite(denom) or not np.isfinite(numer):
        return "n/a"
    pct = 100.0 * numer / denom
    return f"{pct:.{decimals}f}%"

def extract_year(series):
    """
    Return the calendar year of each date in *series* as a nullable Int64 Series.

    Unparseable values, and years outside 1700-2100 (inclusive), become <NA>.
    """
    years = pd.to_datetime(series, errors="coerce").dt.year
    plausible = years.between(1700, 2100)
    return years.where(plausible).astype("Int64")

def weighted_mean(values, weights):
    """
    Weighted average of *values* using *weights* (e.g. headcount).

    Only positions where both the value and the weight are present are used.
    Returns None if no such position exists or their weights don't sum to a
    positive number.
    """
    both_present = values.notna() & weights.notna()
    if not both_present.any():
        return None

    kept_values = values.loc[both_present].astype(float)
    kept_weights = weights.loc[both_present].astype(float)

    weight_total = float(kept_weights.sum())
    if weight_total <= 0:
        return None
    return float(kept_values.mul(kept_weights).sum()) / weight_total

In [None]:
def _rounded_sum(values):
    """Sum the non-missing entries of *values*, rounded to the nearest int (0 if all are missing)."""
    if not values.notna().any():
        return 0
    return int(round(float(values.sum())))


def compute_section_metrics(df_sub, total_n, include_total_employees=False, include_median_year=False):
    """
    Compute the formatted summary rows for one section of the comparison table.

    A section is either the whole subgroup ("Total") or one size band.

    Args:
        df_sub: Organizations in this section. Needs 'total_headcount' and
            'adjusted_synthetic_q10/q50/q90'; 'Founded Date' is optional.
        total_n: Organization count of the whole subgroup, used as the
            denominator of the "Organization N" percentage.
        include_total_employees: Add a "Total employees" row.
        include_median_year: Add a "Median founding year" row (only when
            'Founded Date' is present).

    Returns:
        dict mapping row label -> formatted string.

    NOTE(review): the ML rows sum only the adjusted_synthetic_* quantiles,
    whereas subgroup assignment prefers the pure probit q10/q50/q90 when they
    are available — confirm this difference is intended.
    """
    out = {}
    org_n = len(df_sub)

    # Organization N with percentage of the whole subgroup
    out["Organization N"] = f"{format_int_iso(org_n)} ({pct_string(org_n, total_n)})"

    hc = pd.to_numeric(df_sub['total_headcount'], errors='coerce')
    total_emp = _rounded_sum(hc)

    # Total employees (only for Total section)
    if include_total_employees:
        out["Total employees"] = format_int_iso(total_emp)

    # Median founding year (only for Total section)
    if include_median_year and 'Founded Date' in df_sub.columns:
        years = extract_year(df_sub['Founded Date'])
        med_year = int(np.round(years.dropna().median())) if years.notna().any() else None
        out["Median founding year"] = "" if med_year is None else str(med_year)

    # Median total employees
    hc_valid = hc.dropna()
    med_emp = int(np.round(hc_valid.median())) if len(hc_valid) > 0 else None
    out["Median total employees"] = "" if med_emp is None else format_int_iso(med_emp)

    # ML engineers: summed q50 with the summed q10-q90 range.
    ml_q50 = pd.to_numeric(df_sub['adjusted_synthetic_q50'], errors='coerce')
    ml_q10 = pd.to_numeric(df_sub['adjusted_synthetic_q10'], errors='coerce')
    ml_q90 = pd.to_numeric(df_sub['adjusted_synthetic_q90'], errors='coerce')

    # Round rather than truncate: int() on a float sum always rounds down.
    total_q50 = _rounded_sum(ml_q50)
    total_q10 = _rounded_sum(ml_q10)
    total_q90 = _rounded_sum(ml_q90)

    out["ML engineers (q50)"] = f"{format_int_iso(total_q50)} ({format_int_iso(total_q10)} - {format_int_iso(total_q90)})"

    # Percentage ML talent (weighted by headcount)
    if total_emp > 0:
        # Use the unrounded q50 sum so display rounding doesn't bias the share.
        base_pct = pct_string(float(ml_q50.sum()), total_emp)

        # Headcount-weighted mean share for the interval bounds
        share_q10 = ml_q10 / hc
        share_q90 = ml_q90 / hc

        low = weighted_mean(share_q10, hc)
        high = weighted_mean(share_q90, hc)

        if low is not None and high is not None:
            interval = f"{100*low:.{PCT_DECIMALS}f}% - {100*high:.{PCT_DECIMALS}f}%"
            out["ML % of total"] = f"{base_pct} ({interval})"
        else:
            out["ML % of total"] = base_pct
    else:
        out["ML % of total"] = "n/a"

    return out

In [None]:
def compute_regional_breakdown(df_sub, total_n):
    """
    Count organizations and sum headcount for every region in REGION_ORDER.

    Args:
        df_sub: Organizations of one subgroup; needs 'Subregion' and 'total_headcount'.
        total_n: Denominator for the organization-count percentage.

    Returns:
        dict with two entries per region, in REGION_ORDER order:
        "<region> (orgs)" -> "count (pct%)" and
        "<region> (employees)" -> summed headcount (missing values ignored).
    """
    headcount = pd.to_numeric(df_sub['total_headcount'], errors='coerce')
    region_col = df_sub['Subregion']

    breakdown = {}
    for region in REGION_ORDER:
        in_region = region_col == region
        n_orgs = int(in_region.sum())
        employees = int(headcount[in_region].dropna().sum()) if n_orgs else 0
        breakdown[f"{region} (orgs)"] = f"{format_int_iso(n_orgs)} ({pct_string(n_orgs, total_n)})"
        breakdown[f"{region} (employees)"] = format_int_iso(employees)
    return breakdown

In [None]:
def summarize_subgroup(df_sub):
    """
    Build every section of the comparison table for one subgroup.

    Returns:
        Nested dict: section name ("Total", each SIZE_SECTIONS label,
        "Regions") -> row label -> formatted value. Percentages in every
        section are relative to the whole subgroup's organization count.
    """
    n_total = len(df_sub)
    headcount = pd.to_numeric(df_sub['total_headcount'], errors='coerce')

    summary = {
        "Total": compute_section_metrics(
            df_sub, n_total,
            include_total_employees=True,
            include_median_year=True,
        )
    }

    # Rows with missing headcount belong to no size band.
    for band_label, in_band in SIZE_SECTIONS:
        band_rows = df_sub[headcount.notna() & in_band(headcount)]
        summary[band_label] = compute_section_metrics(
            band_rows, n_total,
            include_total_employees=False,
            include_median_year=False,
        )

    summary["Regions"] = compute_regional_breakdown(df_sub, n_total)
    return summary

In [None]:
def build_comparison_table(subgroups_dict):
    """
    Build the side-by-side comparison table across subgroups.

    Args:
        subgroups_dict: Subgroup key (e.g. 'all', 'probable') -> DataFrame.

    Returns:
        DataFrame indexed by "Characteristic" (section headers plus indented
        row labels), with one column per subgroup. Known keys get display
        names such as 'Non-zero'; header rows are blank.
    """
    total_labels = ["Organization N", "Total employees", "Median founding year",
                    "Median total employees", "ML engineers (q50)", "ML % of total"]
    band_labels = ["Organization N", "Median total employees",
                   "ML engineers (q50)", "ML % of total"]

    # Each layout entry: (display label, "header" or "row", (section, row label) or None)
    layout = [("Total", "header", None)]
    layout += [(f"{INDENT}{lbl}", "row", ("Total", lbl)) for lbl in total_labels]

    for band, _ in SIZE_SECTIONS:
        layout.append((band, "header", None))
        layout += [(f"{INDENT}{lbl}", "row", (band, lbl)) for lbl in band_labels]

    # Regions: all org counts first, then all employee sums. The display
    # label omits the "(orgs)"/"(employees)" suffix that the lookup key keeps.
    for suffix in ("orgs", "employees"):
        layout.append((f"Regions ({suffix})", "header", None))
        layout += [
            (f"{INDENT}{region}", "row", ("Regions", f"{region} ({suffix})"))
            for region in REGION_ORDER
        ]

    column_titles = {
        'all': 'All',
        'probable': 'Probable',
        'possible': 'Possible',
        'nonzero': 'Non-zero',
        'not_detected': 'Not Detected',
    }

    columns = {}
    for key, group_df in subgroups_dict.items():
        summary = summarize_subgroup(group_df)
        columns[column_titles.get(key, key)] = [
            "" if kind == "header" else summary.get(lookup[0], {}).get(lookup[1], "")
            for _, kind, lookup in layout
        ]

    table = pd.DataFrame(columns, index=[entry[0] for entry in layout])
    table.index.name = "Characteristic"
    return table

In [None]:
# Assemble the cross-subgroup comparison and show it as the cell output.
summary_table = build_comparison_table(subgroups)
print(f"Summary table: {summary_table.shape[0]} rows x {summary_table.shape[1]} columns")
summary_table

In [None]:
# Optional: Save summary table
if SAVE_OUTPUTS:
    output_path = DATA_DIR / 'output' / 'combined_summary.csv'
    # Create the folder here too, so this cell doesn't raise FileNotFoundError
    # when the Step 2 save cell has not been run.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    summary_table.to_csv(output_path)
    print(f"Saved summary table to {output_path}")
else:
    print("Summary table save skipped (set SAVE_OUTPUTS=True to enable)")

### Step 3b: Individual Company Table

Detailed table listing each organization with:
- Company info (name, founded year, country, headcount)
- Individual ML estimator values
- Debiased log-median ML talent estimate with 80% CI
- ML share percentage with 80% CI
- Category classification (Enterprise/Mid-Scale/Boutique or "-")

In [None]:
def _founded_year_str(value):
    """Return the founding year of *value* as a string, or "" when it is unknown.

    Bare numeric years (e.g. 2015 or 2015.0, as read from a CSV column holding
    only years) are used directly: ``pd.to_datetime(2015)`` would interpret them
    as nanoseconds since the Unix epoch and yield 1970.
    """
    if pd.isna(value):
        return ""
    is_number = isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool)
    if is_number and float(value).is_integer() and 1000 <= value <= 9999:
        return str(int(value))
    try:
        dt = pd.to_datetime(value, errors='coerce')
    except (ValueError, TypeError):
        return ""
    return str(dt.year) if pd.notna(dt) else ""


def _estimator_value_str(row, col):
    """Format one raw estimator value as an integer string, or "-" if absent/non-numeric."""
    if col not in row.index:
        return "-"
    val = pd.to_numeric(row[col], errors='coerce')
    return str(int(val)) if pd.notna(val) else "-"


def _company_row_quantiles(row):
    """Return ``(q10, q50, q90, use_synthetic)`` for one organization row.

    Pure probit quantiles (q10/q50/q90) are used when q50 is present; otherwise
    the adjusted synthetic quantiles (adjusted_synthetic_q10/q50/q90) are used.
    Missing or non-numeric values come back as NaN.
    """
    q50_pure = pd.to_numeric(row.get('q50', np.nan), errors='coerce')
    use_synthetic = pd.isna(q50_pure)
    prefix = 'adjusted_synthetic_' if use_synthetic else ''
    q10, q50, q90 = (
        pd.to_numeric(row.get(f'{prefix}{q}', np.nan), errors='coerce')
        for q in ('q10', 'q50', 'q90')
    )
    return q10, q50, q90, use_synthetic


def create_company_table(df_input, subgroup_name=""):
    """
    Create detailed company table with ML estimates.

    One output row per input row, in input order. The ML estimate comes from
    the pure probit quantiles when q50 is present, otherwise from the adjusted
    synthetic quantiles; synthetic estimates are suffixed with " *".

    Args:
        df_input: DataFrame with organization data. Reads 'organization_name',
            'Founded Date', 'Country', 'total_headcount', the six estimator
            columns and the (adjusted_synthetic_)q10/q50/q90 columns; absent
            columns render as "" or "-".
        subgroup_name: Name of the subgroup. Currently unused; kept for
            backward compatibility.

    Returns:
        DataFrame of formatted (string) company data. Always has the full set
        of columns, even when df_input is empty, so callers can sort on
        'ML Talent q50 (q10 - q90)' unconditionally.
    """
    # Estimator columns in display order
    estimator_cols = [
        'filter_broad_yes_strict_no',
        'filter_strict_no',
        'filter_broad_yes',
        'claude_total_accepted',
        'gpt5_total_accepted',
        'gemini_total_accepted',
    ]
    estimates_col = 'Individual Estimates [broad+strict, strict, broad, claude, gpt5, gemini]'
    columns = [
        'Company Name',
        'Founded',
        'Country',
        'Total Staff (LinkedIn)',
        estimates_col,
        'ML Talent q50 (q10 - q90)',
        'ML % of Total',
        'Category',
    ]

    rows = []
    for _, row in df_input.iterrows():
        founded_year = _founded_year_str(row['Founded Date']) if 'Founded Date' in row.index else ""

        # Zero estimates are shown as "0"; only missing values become "-"
        estimator_values = [_estimator_value_str(row, col) for col in estimator_cols]

        ml_q10, ml_q50, ml_q90, use_synthetic = _company_row_quantiles(row)
        has_ci = pd.notna(ml_q10) and pd.notna(ml_q90)

        # ML talent estimate with CI, e.g. "12 (3 - 40) *"
        if pd.notna(ml_q50):
            ml_str = f"{format_int_iso(int(ml_q50))}"
            if has_ci:
                ml_str += f" ({format_int_iso(int(ml_q10))} - {format_int_iso(int(ml_q90))})"
            if use_synthetic:
                ml_str += " *"  # Mark synthetic estimates
        else:
            ml_str = "-"

        # ML share of total headcount (uses the unrounded quantiles)
        headcount = pd.to_numeric(row.get('total_headcount', np.nan), errors='coerce')
        if pd.notna(ml_q50) and pd.notna(headcount) and headcount > 0:
            ml_pct_str = f"{100.0 * ml_q50 / headcount:.2f}%"
            if has_ci:
                pct_low = 100.0 * ml_q10 / headcount
                pct_high = 100.0 * ml_q90 / headcount
                ml_pct_str += f" ({pct_low:.2f}% - {pct_high:.2f}%)"
        else:
            ml_pct_str = "-"

        # Confidence category from the shared helper
        category = assign_confidence_category(ml_q10, ml_q50, ml_q90)

        rows.append({
            'Company Name': row.get('organization_name', ''),
            'Founded': founded_year,
            'Country': row.get('Country', ''),
            'Total Staff (LinkedIn)': format_int_iso(int(headcount)) if pd.notna(headcount) else "-",
            estimates_col: f"[{', '.join(estimator_values)}]",
            'ML Talent q50 (q10 - q90)': ml_str,
            'ML % of Total': ml_pct_str,
            'Category': category,
        })

    return pd.DataFrame(rows, columns=columns)


# Legend for the company tables: create_company_table appends " *" to the
# 'ML Talent q50 (q10 - q90)' value when the pure probit q50 is missing and the
# adjusted synthetic quantiles are used instead.
print("* = Synthetic estimate (adjusted_synthetic_q50 used when pure probit q50 is empty)")

In [None]:
# Company table for the Probable subgroup (q10 > 0). create_company_table keeps
# the input row order; the descending sort by ML talent estimate is applied below.
company_table = create_company_table(df_probable)

# Sort by ML talent (need to extract numeric value for sorting)
def extract_ml_for_sort(ml_str):
    """Return the q50 ML talent estimate at the start of *ml_str* as an int.

    create_company_table renders the estimate as e.g. ``"1 234 (500 - 2 000) *"``:
    the q50 value, an optional ``(q10 - q90)`` interval and an optional ``*``
    synthetic marker. Only the leading q50 number is used as the sort key.

    Args:
        ml_str: Formatted ML talent string, or "-" when no estimate exists.

    Returns:
        The leading integer; 0 for "-", empty or non-string input, or when the
        string does not start with a digit.
    """
    if not isinstance(ml_str, str) or ml_str.strip() in ("", "-"):
        return 0
    # Thousands separators may be a plain space or a Unicode space (thin or
    # narrow no-break space), so grab the leading run of digits/whitespace and
    # keep only the digits.
    match = re.match(r'\s*(\d[\d\s]*)', ml_str)
    if match is None:
        return 0
    return int(re.sub(r'\D', '', match.group(1)))

# Order rows by the q50 estimate parsed from the formatted ML talent column, largest first
company_table = company_table.sort_values(
    'ML Talent q50 (q10 - q90)',
    key=lambda col: col.map(extract_ml_for_sort),
    ascending=False,
)

print(f"Company Table — Probable ({len(company_table)} organizations)")
print("Sorted by ML Talent estimate (descending)")
print()
company_table

In [None]:
# Optional: Save company table
if SAVE_OUTPUTS:
    output_path = DATA_DIR / 'output' / 'company_table_probable.csv'
    # DataFrame.to_csv does not create missing parent directories
    output_path.parent.mkdir(parents=True, exist_ok=True)
    company_table.to_csv(output_path, index=False)
    print(f"Saved company table to {output_path}")

### Step 3c: Individual Company Table — All Organizations

Same detailed table as above, but for all 403 organizations in the dataset.

In [None]:
# Company table for ALL organizations, ordered by the parsed q50 ML talent estimate (largest first)
company_table_all = create_company_table(df_all).sort_values(
    'ML Talent q50 (q10 - q90)',
    key=lambda col: col.map(extract_ml_for_sort),
    ascending=False,
)

print(f"Company Table — All Organizations ({len(company_table_all)} organizations)")
print("Sorted by ML Talent estimate (descending)")
print()
company_table_all

In [None]:
# Optional: Save company table for all organizations
if SAVE_OUTPUTS:
    output_path = DATA_DIR / 'output' / 'company_table_all.csv'
    # DataFrame.to_csv does not create missing parent directories
    output_path.parent.mkdir(parents=True, exist_ok=True)
    company_table_all.to_csv(output_path, index=False)
    print(f"Saved company table to {output_path}")

---
## Step 4: ML Estimates Visualization

Plot individual estimates and confidence intervals for organizations with Probable ML presence (q10 > 0).

- Individual markers for each raw estimator (keyword filters + LLMs)
- Confidence intervals: 
  - **Pure Probit 80% CI** (q10, q50, q90) when available
  - **Adjusted Synthetic Probit 80% CI** when pure probit is empty

In [None]:
def create_ml_estimates_plot(df_plot, figsize=(16, 8)):
    """
    Create visualization of ML estimates for organizations.

    Design approach:
    - Keyword filters shown as small, muted markers (background context)
    - LLM estimates shown as distinct colored markers (primary focus)
    - Confidence intervals with clear visual distinction (pure probit vs synthetic)
    - Legend in three columns below the plot area

    For each organization the pure probit quantiles (q10/q50/q90) are used when
    q50 is present; otherwise the adjusted synthetic quantiles are used. Only
    finite, strictly positive estimator values are drawn. The legend is built
    from the styles actually used for plotting, so every drawn estimator
    appears in it.

    Args:
        df_plot: DataFrame with organizations to plot. Must contain q10/q50/q90
            and adjusted_synthetic_q10/q50/q90; estimator columns are optional.
        figsize: Figure size tuple

    Returns:
        fig, ax, df_sorted -- df_sorted is a copy of df_plot with helper columns
        _use_pure_probit, _central, _lower, _upper and _sort_key, sorted
        ascending by central estimate (missing treated as 0), index reset.
    """
    filter_cols = ['filter_broad_yes', 'filter_strict_no', 'filter_broad_yes_strict_no']
    llm_cols = ['gemini_total_accepted', 'claude_total_accepted', 'gpt5_total_accepted']

    df_sorted = df_plot.copy()

    # Choose the CI source per org: pure probit where q50 exists, else adjusted synthetic
    df_sorted['_use_pure_probit'] = pd.to_numeric(df_sorted['q50'], errors='coerce').notna()
    for out_col, quantile in (('_central', 'q50'), ('_lower', 'q10'), ('_upper', 'q90')):
        df_sorted[out_col] = np.where(
            df_sorted['_use_pure_probit'],
            pd.to_numeric(df_sorted[quantile], errors='coerce'),
            pd.to_numeric(df_sorted[f'adjusted_synthetic_{quantile}'], errors='coerce'),
        )

    # Sort by central estimate (consistent with create_ml_estimates_plot_all_orgs)
    df_sorted['_sort_key'] = df_sorted['_central'].fillna(0)
    df_sorted = df_sorted.sort_values('_sort_key').reset_index(drop=True)

    fig, ax = plt.subplots(figsize=figsize)
    x = np.arange(len(df_sorted))
    offset_step = 0.10  # horizontal jitter between estimators of the same org

    def _scatter_estimators(cols, default_style, **scatter_kwargs):
        """Scatter each present estimator column; return the (style, label) pairs plotted."""
        plotted = []
        for i, col in enumerate(cols):
            if col not in df_sorted.columns:
                continue
            y = pd.to_numeric(df_sorted[col], errors='coerce').values
            mask = np.isfinite(y) & (y > 0)  # log axis: positive, finite values only
            x_pos = x + (i - 1) * offset_step
            style = ESTIMATOR_STYLES.get(col, default_style)
            label = ESTIMATOR_LABELS.get(col, col)
            ax.scatter(
                x_pos[mask], y[mask],
                s=style['size'], marker=style['marker'], c=style['color'],
                label=label, **scatter_kwargs,
            )
            plotted.append((style, label))
        return plotted

    # Layer 1: keyword filter estimates (background, muted)
    filter_entries = _scatter_estimators(
        filter_cols, {'color': PALETTE['gray'], 'marker': 'o', 'size': 24},
        alpha=0.35, linewidths=0, zorder=1,
    )
    # Layer 2: LLM estimates (foreground, distinctive)
    llm_entries = _scatter_estimators(
        llm_cols, {'color': PALETTE['blue'], 'marker': 'D', 'size': 36},
        alpha=0.85, edgecolors='white', linewidths=0.5, zorder=2,
    )

    # -------------------------------------------------------------------------
    # Layer 3: Confidence intervals (top layer)
    # -------------------------------------------------------------------------
    central = df_sorted['_central'].values
    lower = df_sorted['_lower'].values
    upper = df_sorted['_upper'].values
    use_pure = df_sorted['_use_pure_probit'].values

    # Clamp to a small positive epsilon so points/bars can render on the log axis
    eps = 0.5
    lower = np.maximum(lower, eps)
    central = np.maximum(central, eps)

    yerr_lower = np.clip(central - lower, 0, None)
    yerr_upper = np.clip(upper - central, 0, None)
    mask_valid = np.isfinite(central) & (central > 0)

    ci_entries = []
    for ci_mask, color, label in (
        (mask_valid & use_pure, COLORS['ci_pure_probit'], 'Pure Probit 80% CI'),
        (mask_valid & ~use_pure, COLORS['ci_adjusted_synthetic'], 'Adjusted Synthetic 80% CI'),
    ):
        if not np.any(ci_mask):
            continue
        ax.errorbar(
            x[ci_mask], central[ci_mask],
            yerr=np.vstack([yerr_lower[ci_mask], yerr_upper[ci_mask]]),
            fmt='o',
            mfc='white', mec=color, mew=1.8, ms=5,
            ecolor=color, elinewidth=1.2, capsize=2.5, capthick=1.2,
            zorder=4,
        )
        ci_entries.append((color, label))

    # -------------------------------------------------------------------------
    # Axis formatting
    # -------------------------------------------------------------------------
    format_log_axis(ax, axis='y', limits=(1, 10000))

    ax.set_xlabel('Organizations (sorted by ML estimate)', fontsize=FONT_SIZES['axis_label'])
    ax.set_ylabel('Estimated ML Talent', fontsize=FONT_SIZES['axis_label'])
    ax.set_title('ML Talent Estimates by Organization', fontsize=FONT_SIZES['title'], fontweight='medium', pad=10)

    # X-axis labels (organization names)
    if 'organization_name' in df_sorted.columns:
        ax.set_xticks(x)
        ax.set_xticklabels(
            df_sorted['organization_name'].astype(str).tolist(),
            rotation=45, ha='right',
            fontsize=FONT_SIZES['org_label']
        )

    # -------------------------------------------------------------------------
    # Legend -- 3 columns below the plot
    # -------------------------------------------------------------------------
    # matplotlib fills legend entries column by column, so each column is built
    # as header + items, padded with invisible spacers to equal length:
    #   Confidence Intervals | LLM Estimates | Keyword Filters
    ci_items = [
        Line2D([0], [0], marker='o', color=color,
               markerfacecolor='white', markeredgecolor=color,
               markeredgewidth=1.8, markersize=6,
               linestyle='-', linewidth=1.2, label=f'  {label}')
        for color, label in ci_entries
    ]
    llm_items = [
        Line2D([0], [0], marker=style['marker'], color='w',
               markerfacecolor=style['color'], markeredgecolor='white',
               markersize=7, linestyle='None', label=f'  {label}')
        for style, label in llm_entries
    ]
    keyword_items = [
        Line2D([0], [0], marker=style['marker'], color='w',
               markerfacecolor=style['color'], markeredgecolor='none',
               markersize=5, linestyle='None', alpha=0.5, label=f'  {label}')
        for style, label in filter_entries
    ]

    legend_columns = [
        ('Confidence Intervals:', ci_items),
        ('LLM Estimates:', llm_items),
        ('Keyword Filters:', keyword_items),
    ]
    n_items = max(len(items) for _, items in legend_columns)
    spacer = Line2D([0], [0], color='none', linestyle='None', label=' ')  # Invisible spacer
    legend_elements = []
    for header, items in legend_columns:
        legend_elements.append(Line2D([0], [0], color='none', label=header))
        legend_elements.extend(items)
        legend_elements.extend([spacer] * (n_items - len(items)))

    ax.legend(
        handles=legend_elements,
        loc='upper center',
        bbox_to_anchor=(0.5, -0.35),
        ncol=3,
        fontsize=FONT_SIZES['legend'],
        frameon=False,
        columnspacing=2.5,
        handletextpad=0.5,
    )

    # Subtle grid
    ax.grid(True, which='major', alpha=0.20, linewidth=0.4, color=COLORS['gridline'])
    ax.set_axisbelow(True)

    # Use fixed margins only - tight_layout can cause figure size explosion
    fig.subplots_adjust(left=0.05, right=0.95, top=0.92, bottom=0.35)

    return fig, ax, df_sorted

In [None]:
# Create the plot for Probable organizations (q10 > 0)
fig, ax, df_plot_sorted = create_ml_estimates_plot(df_probable, figsize=(16, 8))

# Print summary
n_pure = int(df_plot_sorted['_use_pure_probit'].sum())
n_synthetic = len(df_plot_sorted) - n_pure
print("Plot Summary:")
print(f"  Total organizations: {len(df_plot_sorted)}")
print(f"  Using Pure Probit CI: {n_pure}")
print(f"  Using Adjusted Synthetic CI: {n_synthetic}")

# Save to output directory
if SAVE_OUTPUTS:
    output_dir = DATA_DIR / 'output'
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / 'ml_estimates_probable.png'
    # The legend is anchored below the axes; 'tight' keeps it inside the saved
    # image, matching how the other figures in this notebook are saved.
    fig.savefig(output_path, dpi=200, bbox_inches='tight')
    print(f"\nSaved: {output_path}")

plt.close(fig)

### Step 4b: ML Estimates Visualization — Confident Estimates

Plot of ML estimates for organizations where the 80% CI excludes zero (q10 > 0). Unlike Step 4, this starts from the full dataset (`df_all`) and applies the q10 > 0 filter internally, so the number of excluded organizations can be reported.

- **Statistical filter**: Only includes organizations where we're confident ML talent exists
- **Rationale**: When q10 > 0, the confidence interval excludes zero, so the remaining uncertainty is about *how many* ML engineers an organization has, not *whether* it has any
- **Axis layout**: Organizations on the x-axis (sorted by ML estimate), ML estimates on a logarithmic y-axis

In [None]:
def create_ml_estimates_plot_all_orgs(df_plot, figsize=(16, 8), title_suffix=""):
    """
    Create visualization of ML estimates, filtering to confident organizations.

    Filters df_plot to organizations whose effective q10 is > 0 (80% CI
    excludes zero), then delegates all drawing to create_ml_estimates_plot()
    so the two plots cannot drift apart. The only differences from calling
    create_ml_estimates_plot() directly are:
    1. Accepts the full dataset (df_all) as input and filters internally
    2. Appends title_suffix to the plot title
    3. Lets the caller compare len(df_plot) with the returned rows to report
       how many orgs were excluded

    The effective q10 is the pure probit q10 when pure probit q50 is present,
    otherwise adjusted_synthetic_q10. Orgs with a missing effective q10 are
    excluded as well (NaN never compares > 0).

    Args:
        df_plot: DataFrame with ALL organization data (will be filtered internally)
        figsize: Figure size tuple
        title_suffix: Optional suffix for plot title

    Returns:
        fig, ax, df_sorted (filtered to q10 > 0 only)
    """
    # Same pure-probit-else-synthetic selection as create_ml_estimates_plot,
    # used here only to decide which rows to keep.
    use_pure_probit = pd.to_numeric(df_plot['q50'], errors='coerce').notna()
    effective_lower = np.where(
        use_pure_probit,
        pd.to_numeric(df_plot['q10'], errors='coerce'),
        pd.to_numeric(df_plot['adjusted_synthetic_q10'], errors='coerce'),
    )

    # The statistically principled cut: keep orgs where we're confident ML talent exists
    df_confident = df_plot[effective_lower > 0]

    fig, ax, df_sorted = create_ml_estimates_plot(df_confident, figsize=figsize)
    ax.set_title(
        f'ML Talent Estimates by Organization{title_suffix}',
        fontsize=FONT_SIZES['title'], fontweight='medium', pad=10,
    )
    return fig, ax, df_sorted

In [None]:
# Create the plot for ALL organizations with confident estimates (q10 > 0)
fig_all, ax_all, df_all_sorted = create_ml_estimates_plot_all_orgs(
    df_all,
    figsize=(16, 8),
    title_suffix=" — Confident Estimates (q10 > 0)"
)

# Print summary
n_total = len(df_all)
n_plotted = len(df_all_sorted)
# Excluded rows are those whose effective q10 is 0 or missing (no estimate at all)
n_excluded = n_total - n_plotted
n_pure = int(df_all_sorted['_use_pure_probit'].sum())
n_synthetic = n_plotted - n_pure

print("Plot Summary:")
print(f"  Total organizations: {n_total}")
print(f"  Plotted (q10 > 0, CI excludes zero): {n_plotted}")
print(f"  Excluded (q10 = 0 or missing): {n_excluded}")
print(f"  Using Pure Probit CI: {n_pure}")
print(f"  Using Adjusted Synthetic CI: {n_synthetic}")

# Save to output directory
if SAVE_OUTPUTS:
    output_path = DATA_DIR / 'output' / 'ml_estimates_all_orgs.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig_all.savefig(output_path, dpi=200, bbox_inches='tight')
    print(f"\nSaved: {output_path}")

plt.close(fig_all)

### Step 4c: Summary Visualizations

Two summary plots showing the distribution of ML estimate confidence across all 403 organizations:

1. **Filtering Funnel**: Shows how organizations filter down from total to confident estimates
2. **Regional Breakdown**: Stacked bar chart showing confidence levels by subregion

In [None]:
# =============================================================================
# Summary Visualizations: Filtering Funnel & Regional Breakdown
# =============================================================================

# Effective quantiles computed in an earlier cell; these are aliases, not a
# recomputation.
_q10 = effective_q10
_q50 = effective_q50
_q90 = effective_q90

# Assign one of the 4 confidence categories per org using the centralized function
df['_category'] = [assign_confidence_category(q10, q50, q90)
                   for q10, q50, q90 in zip(_q10, _q50, _q90)]

# Category -> color (COLORS from the design-system cell); shared by both plots so
# a category keeps its color even when another category has no members.
category_colors = {
    'Probable': COLORS['probable'],
    'Possible': COLORS['possible'],
    'Non-zero': COLORS['nonzero'],
    'Not Detected': COLORS['not_detected'],
}

# -----------------------------------------------------------------------------
# Plot 1: Filtering Funnel (4 stages)
# -----------------------------------------------------------------------------
fig_funnel, ax_funnel = plt.subplots(figsize=(10, 6))

stages = ['Total\nOrganizations', 'Non-zero\n(q90 > 0)', 'Possible\n(q50 > 0)', 'Probable\n(q10 > 0)']
values = [
    len(df),
    (_q90 > 0).sum(),
    (_q50 > 0).sum(),
    (_q10 > 0).sum()
]

colors_funnel = [category_colors[c] for c in ('Not Detected', 'Non-zero', 'Possible', 'Probable')]
y_pos = np.arange(len(stages))
bars = ax_funnel.barh(y_pos, values, color=colors_funnel, edgecolor='white', linewidth=1, height=0.7)

# Percentages are guarded so an empty stage does not raise ZeroDivisionError
funnel_total = values[0]
for bar, val in zip(bars, values):
    pct = val / funnel_total * 100 if funnel_total else 0.0
    ax_funnel.annotate(f'{val} ({pct:.0f}%)', xy=(val + 5, bar.get_y() + bar.get_height()/2),
                       va='center', fontsize=FONT_SIZES['annotation'], fontweight='medium')

for i in range(1, len(values)):
    prev = values[i - 1]
    drop = prev - values[i]
    drop_pct = drop / prev * 100 if prev else 0.0
    ax_funnel.annotate(f'−{drop} ({drop_pct:.0f}%)',
                       xy=(prev - drop/2, (y_pos[i-1] + y_pos[i])/2),
                       ha='center', va='center', fontsize=FONT_SIZES['annotation']-1,
                       color='#DC3220', alpha=0.8)

ax_funnel.set_yticks(y_pos)
ax_funnel.set_yticklabels(stages, fontsize=FONT_SIZES['tick_label'])
ax_funnel.set_xlabel('Number of Organizations', fontsize=FONT_SIZES['axis_label'])
ax_funnel.set_title('Filtering Funnel: From All Organizations to Probable ML Estimates',
                    fontsize=FONT_SIZES['title'], fontweight='medium')
# Floor at 1 so an empty dataset does not produce identical x-limits
ax_funnel.set_xlim(0, max(max(values), 1) * 1.25)
ax_funnel.invert_yaxis()
ax_funnel.spines['top'].set_visible(False)
ax_funnel.spines['right'].set_visible(False)

plt.tight_layout()

if SAVE_OUTPUTS:
    funnel_path = DATA_DIR / 'output' / 'ml_filtering_funnel.png'
    funnel_path.parent.mkdir(parents=True, exist_ok=True)
    fig_funnel.savefig(funnel_path, dpi=200, bbox_inches='tight')
    print(f"Saved: {funnel_path}")

plt.close(fig_funnel)

# -----------------------------------------------------------------------------
# Plot 2: Stacked Bar by Region (4 categories)
# -----------------------------------------------------------------------------
fig_region, ax_region = plt.subplots(figsize=(12, 7))

region_col = 'Subregion' if 'Subregion' in df.columns else 'subregion'
region_cat = pd.crosstab(df[region_col], df['_category'])

col_order = ['Probable', 'Possible', 'Non-zero', 'Not Detected']
region_cat = region_cat[[c for c in col_order if c in region_cat.columns]]

region_cat['_total'] = region_cat.sum(axis=1)
region_cat = region_cat.sort_values('_total', ascending=True).drop(columns='_total')

# Colors follow the columns actually present, so a missing category cannot
# shift the remaining colors onto the wrong categories.
region_cat.plot(kind='barh', stacked=True, ax=ax_region,
                color=[category_colors[c] for c in region_cat.columns],
                edgecolor='white', linewidth=0.5)

ax_region.set_xlabel('Number of Organizations', fontsize=FONT_SIZES['axis_label'])
ax_region.set_ylabel('Subregion', fontsize=FONT_SIZES['axis_label'])
ax_region.set_title('ML Estimate Confidence by Subregion', fontsize=FONT_SIZES['title'], fontweight='medium')
ax_region.legend(title='Category', fontsize=FONT_SIZES['legend'], frameon=False, loc='lower right')

plt.tight_layout()

if SAVE_OUTPUTS:
    region_path = DATA_DIR / 'output' / 'ml_confidence_by_region.png'
    region_path.parent.mkdir(parents=True, exist_ok=True)
    fig_region.savefig(region_path, dpi=200, bbox_inches='tight')
    print(f"Saved: {region_path}")

plt.close(fig_region)

print("\nCategory breakdown:")
print(df['_category'].value_counts())

---
## Step 5: ML Talent Landscape Plot

Scatter plot showing ML staff count vs ML share (%) for all organizations.

**Confidence Categories** (based on statistical estimates):
- **Probable**: q10 > 0 — 80% CI excludes zero
- **Possible**: q50 > 0, q10 = 0 — Central estimate positive
- **Non-zero**: q90 > 0, q50 = 0 — Upper bound positive only
- **Not Detected**: All zeros

In [None]:
# adjustText is optional: it only improves label placement in the landscape plot
try:
    from adjustText import adjust_text
except ImportError:
    HAS_ADJUSTTEXT = False
    print("Note: Install adjustText for better label placement: pip install adjustText")
else:
    HAS_ADJUSTTEXT = True


def create_landscape_plot(df_input, title_suffix="", x_max=100, figsize=(11, 7), max_labels=25, log_x=True):
    """
    Create ML talent landscape scatter plot.
    
    Design approach:
    - Clear visual hierarchy: Probable > Possible > Non-zero > Not Detected
    - Points colored by confidence category (based on q10/q50/q90)
    - Labels only for "Probable" organizations (high confidence)
    - Legend positioned outside plot to maximize data space
    
    Confidence Categories:
    - Probable: q10 > 0 (80% CI excludes zero)
    - Possible: q50 > 0, q10 = 0 (central estimate positive but uncertain)
    - Non-zero: q90 > 0, q50 = 0 (upper bound positive only)
    - Not Detected: all zeros (no ML signal)
    
    Estimate source per row: the pure-probit columns ('q10', 'q50', 'q90')
    are used wherever 'q50' is present and non-NaN; otherwise the
    'adjusted_synthetic_q10/q50/q90' columns are used.
    
    Args:
        df_input: DataFrame with organization data. Must contain
            'organization_name', 'total_headcount' and
            'adjusted_synthetic_q10/q50/q90'; 'q10'/'q50'/'q90' are optional.
        title_suffix: Optional suffix for plot title
        x_max: Upper x-axis limit (ML share %)
        figsize: Figure size tuple
        max_labels: Maximum number of organization labels to show (the
            Probable organizations with the largest q50 are labeled)
        log_x: Use logarithmic x-axis (default True)
    
    Returns:
        fig, ax, plot_df
    """
    def _optional_numeric(col):
        # Pure-probit columns may be absent entirely; treat that as all-NaN.
        if col in df_input.columns:
            return pd.to_numeric(df_input[col], errors='coerce')
        return pd.Series(np.nan, index=df_input.index)

    # Prepare data - use pure probit if available, else adjusted synthetic
    q10_pure = _optional_numeric('q10')
    q50_pure = _optional_numeric('q50')
    q90_pure = _optional_numeric('q90')
    q10_synthetic = pd.to_numeric(df_input['adjusted_synthetic_q10'], errors='coerce')
    q50_synthetic = pd.to_numeric(df_input['adjusted_synthetic_q50'], errors='coerce')
    q90_synthetic = pd.to_numeric(df_input['adjusted_synthetic_q90'], errors='coerce')
    
    # Use pure probit when available (indicated by q50 not being NaN)
    has_pure = q50_pure.notna()
    ml_q10 = q10_pure.where(has_pure, q10_synthetic)
    ml_q50 = q50_pure.where(has_pure, q50_synthetic)
    ml_q90 = q90_pure.where(has_pure, q90_synthetic)
    
    plot_df = pd.DataFrame({
        'org': df_input['organization_name'].astype(str),
        'ml_n': ml_q50,  # Use q50 for plotting position
        'ml_q10': ml_q10,
        'ml_q90': ml_q90,
        'emp': pd.to_numeric(df_input['total_headcount'], errors='coerce'),
        'used_pure_probit': has_pure
    })
    
    # Calculate ML share percentage
    plot_df['ml_pct'] = (plot_df['ml_n'] / plot_df['emp']) * 100.0
    plot_df['ml_pct'] = plot_df['ml_pct'].clip(lower=0, upper=100)
    
    # Clean data: drop division artefacts (inf from emp == 0) and missing values
    plot_df = plot_df.replace([np.inf, -np.inf], np.nan).dropna(subset=['ml_n', 'emp', 'ml_pct'])
    plot_df = plot_df[(plot_df['ml_n'] >= 0) & (plot_df['emp'] > 0)]
    
    # Assign confidence categories
    plot_df['cluster'] = [assign_confidence_category(q10, q50, q90) 
                          for q10, q50, q90 in plot_df[['ml_q10', 'ml_n', 'ml_q90']].values]
    
    # Create figure
    fig, ax = plt.subplots(figsize=figsize)
    
    # -------------------------------------------------------------------------
    # Data points by confidence category
    # -------------------------------------------------------------------------
    marker_sizes = {
        'Probable': 80,
        'Possible': 64,
        'Non-zero': 48,
        'Not Detected': 36,
    }
    
    # Plot in order: Not Detected first (background), then others on top
    for cluster_name in ['Not Detected', 'Non-zero', 'Possible', 'Probable']:
        sub_df = plot_df[plot_df['cluster'] == cluster_name]
        if len(sub_df) == 0:
            continue
        
        is_highlight = cluster_name in ['Probable', 'Possible']
        
        ax.scatter(
            sub_df['ml_pct'], sub_df['ml_n'],
            s=marker_sizes[cluster_name],
            c=LANDSCAPE_PALETTE[cluster_name],
            marker=LANDSCAPE_MARKERS[cluster_name],
            alpha=0.85 if is_highlight else 0.45,
            edgecolors='white' if is_highlight else 'none',
            linewidths=0.6 if is_highlight else 0,
            label=cluster_name,
            zorder=3 if is_highlight else 2
        )
    
    # -------------------------------------------------------------------------
    # Labels for Probable organizations - placed in top margin with leader lines
    # Uses greedy first-fit row packing to avoid overlaps
    # -------------------------------------------------------------------------
    labeled_df = plot_df[plot_df['cluster'] == 'Probable'].copy()
    labeled_df = labeled_df.sort_values('ml_n', ascending=False).head(max_labels)
    
    if len(labeled_df) > 0:
        # Sort by x-position (ml_pct) for left-to-right label placement
        labeled_df = labeled_df.sort_values('ml_pct').reset_index(drop=True)
        
        # Only the top of the (still autoscaled, linear) y-axis is needed here.
        # Do NOT unpack get_xlim() into `x_max`: that would overwrite the
        # caller's x_max argument used for the final x-axis limits below.
        _, y_top = ax.get_ylim()
        
        # Label positioning parameters (data coordinates). Rows grow
        # geometrically, so they are evenly spaced once the y-axis is log.
        label_y_base = y_top * 0.6    # First row y position
        row_multiplier = 1.8          # Each row is this much higher
        max_rows = 6                  # Maximum number of label rows
        
        # Estimate text width in log10 x-units (empirically tuned)
        def estimate_label_width_log(text):
            char_width = 0.08  # log10 units per character
            return len(text) * char_width + 0.1  # Add padding
        
        def labels_overlap(x1, text1, x2, text2):
            # Check if two centered labels would overlap (in log10 space);
            # x values are floored at 1e-4 so log10 stays finite for 0%.
            log_x1 = np.log10(max(x1, 0.0001))
            log_x2 = np.log10(max(x2, 0.0001))
            
            w1 = estimate_label_width_log(text1)
            w2 = estimate_label_width_log(text2)
            
            left1, right1 = log_x1 - w1/2, log_x1 + w1/2
            left2, right2 = log_x2 - w2/2, log_x2 + w2/2
            
            # Overlap if intervals intersect
            return not (right1 < left2 or right2 < left1)
        
        # Row packing: each row is a list of (x_pos, org_name, data_row)
        rows = []
        
        for _, data_row in labeled_df.iterrows():
            x_pos = data_row['ml_pct']
            org_name = data_row['org']
            placed = False
            
            # Try to place in existing row (check ALL labels in row for overlap)
            for row_labels in rows:
                if not any(labels_overlap(x_pos, org_name, existing_x, existing_name)
                           for existing_x, existing_name, _ in row_labels):
                    row_labels.append((x_pos, org_name, data_row))
                    placed = True
                    break
            
            # If no existing row works, create new row (up to max_rows);
            # beyond that the label is skipped as too crowded.
            if not placed and len(rows) < max_rows:
                rows.append([(x_pos, org_name, data_row)])
        
        # Draw labels and leader lines
        for row_idx, row_labels in enumerate(rows):
            label_y = label_y_base * (row_multiplier ** row_idx)
            
            for x_pos, org_name, data_row in row_labels:
                point_x = data_row['ml_pct']
                point_y = data_row['ml_n']
                
                # Draw leader line from point to label
                ax.annotate(
                    org_name,
                    xy=(point_x, point_y),  # Point location
                    xytext=(x_pos, label_y),  # Label location
                    fontsize=FONT_SIZES['org_label'],
                    ha='center', va='bottom',
                    color='#404040',
                    arrowprops=dict(
                        arrowstyle='-',
                        lw=0.5,
                        color='#909090',
                        alpha=0.6,
                        connectionstyle='arc3,rad=0'
                    ),
                    annotation_clip=False  # Allow drawing outside axes
                )
    
    # -------------------------------------------------------------------------
    # Axis formatting
    # -------------------------------------------------------------------------
    if log_x:
        ax.set_xscale('log')
        ax.set_xlim(0.001, x_max)  # Start at 0.001% for log scale
        # Clean x-axis ticks for log scale
        ax.set_xticks([0.001, 0.01, 0.1, 1, 10, 100])
        ax.set_xticklabels(['0.001%', '0.01%', '0.1%', '1%', '10%', '100%'])
    else:
        ax.set_xlim(0, x_max)
    format_log_axis(ax, axis='y', limits=(1, 50000))  # Extended to accommodate labels
    
    ax.set_xlabel('ML Share (%)', fontsize=FONT_SIZES['axis_label'])
    ax.set_ylabel('ML Staff Count (q50)', fontsize=FONT_SIZES['axis_label'])
    ax.set_title(f'ML Talent Landscape{title_suffix}', fontsize=FONT_SIZES['title'], fontweight='medium', pad=10)
    
    # -------------------------------------------------------------------------
    # Legend — positioned outside, right side
    # -------------------------------------------------------------------------
    cluster_counts = plot_df['cluster'].value_counts()
    
    legend_handles = []
    for cluster in ['Probable', 'Possible', 'Non-zero', 'Not Detected']:
        count = cluster_counts.get(cluster, 0)
        is_highlight = cluster in ['Probable', 'Possible']
        # Empty scatter used purely as a legend proxy (shows n per category).
        legend_handles.append(
            ax.scatter([], [], 
                s=marker_sizes[cluster] * 0.8,
                c=LANDSCAPE_PALETTE[cluster],
                marker=LANDSCAPE_MARKERS[cluster],
                alpha=0.85 if is_highlight else 0.45,
                edgecolors='white' if is_highlight else 'none',
                linewidths=0.6 if is_highlight else 0,
                label=f'{cluster} (n={count})'
            )
        )
    
    ax.legend(
        handles=legend_handles,
        loc='center left',
        bbox_to_anchor=(1.02, 0.5),
        frameon=True,
        framealpha=0.95,
        edgecolor='#E0E0E0',
        fontsize=FONT_SIZES['legend'],
    )
    
    # -------------------------------------------------------------------------
    # Grid and styling
    # -------------------------------------------------------------------------
    ax.grid(True, which='major', alpha=0.20, linewidth=0.4, color=COLORS['gridline'])
    ax.set_axisbelow(True)
    
    # Use fixed margins only - tight_layout can cause figure size explosion.
    # right=0.78 leaves room for the outside legend.
    fig.subplots_adjust(left=0.08, right=0.78, top=0.92, bottom=0.08)
    
    # Print summary of which estimate was used
    n_pure = plot_df['used_pure_probit'].sum()
    n_synthetic = len(plot_df) - n_pure
    print(f"Estimate source: {n_pure} pure probit, {n_synthetic} adjusted synthetic")
    
    return fig, ax, plot_df

In [None]:
# Plot 1: All organizations (count printed below; not hard-coded here)
fig_all, ax_all, plot_all = create_landscape_plot(
    df_all, 
    title_suffix=" — All Organizations",
    figsize=(11, 7),
    max_labels=25
)

print(f"\nCluster distribution (All Orgs, N={len(plot_all)}):")
for cluster in ['Probable', 'Possible', 'Non-zero', 'Not Detected']:
    count = (plot_all['cluster'] == cluster).sum()
    print(f"  {cluster}: {count}")

# Save to output directory (created, including parents, if missing)
if SAVE_OUTPUTS:
    output_dir = DATA_DIR / 'output'
    output_dir.mkdir(parents=True, exist_ok=True)
    fig_all.savefig(output_dir / 'ml_landscape_all.png', dpi=200)
    print(f"\nSaved: {output_dir / 'ml_landscape_all.png'}")

plt.close(fig_all)

In [None]:
# Landscape plots are saved (or not) by the cell above; report accordingly
if SAVE_OUTPUTS:
    print("All landscape plots saved to output directory")
else:
    print("SAVE_OUTPUTS is False; landscape plots were not saved")

---
## Step 6: Geographic Distribution

Visualize organization distribution by country and subregion:
- **World heat map**: Country frequency choropleth
- **Subregion bar chart**: Horizontal bar chart of organization counts per subregion, stacked by source category when a `Source` column is present

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# =============================================================================
# PLOTLY THEME — Match Academic/Scientific Style
# =============================================================================

# Layout defaults in the nested-dict form Plotly accepts for a template.
# NOTE(review): create_country_heatmap / create_subregion_bar_chart below set
# fonts, margins and backgrounds inline and never reference PLOTLY_TEMPLATE —
# confirm whether it is applied elsewhere (e.g. via `template=`) or is unused.
PLOTLY_TEMPLATE = {
    'layout': {
        'font': {'family': 'Helvetica Neue, Helvetica, Arial, sans-serif', 'size': 12, 'color': '#333333'},
        'title': {'font': {'size': 14, 'color': '#333333'}, 'x': 0.5, 'xanchor': 'center'},
        'paper_bgcolor': 'white',
        'plot_bgcolor': 'white',
        'colorway': PALETTE_LIST,  # categorical color cycle from the matplotlib design system
        'margin': {'l': 60, 'r': 30, 't': 60, 'b': 60},
    }
}

# Custom sequential color scale matching our palette (blue-based).
# Each stop is [position in 0-1, hex color]; passed as color_continuous_scale
# by create_country_heatmap.
SEQUENTIAL_COLORSCALE = [
    [0.0, '#F7FBFF'],   # Very light blue-white
    [0.2, '#DEEBF7'],   # Light blue
    [0.4, '#9ECAE1'],   # Medium light blue
    [0.6, '#4292C6'],   # Medium blue
    [0.8, '#2171B5'],   # Darker blue
    [1.0, '#084594'],   # Deep blue (close to our primary blue)
]


def create_country_heatmap(df_input, title_suffix=""):
    """
    Build a world choropleth coloured by how many organizations each country has.

    Countries are matched by name (Plotly ``locationmode="country names"``);
    missing values in the 'Country' column are not counted.

    Args:
        df_input: DataFrame with a 'Country' column
        title_suffix: Optional text appended to the title

    Returns:
        fig: Plotly figure object
        country_counts: Series of organization counts per country, descending
    """
    font_family = 'Helvetica Neue, Helvetica, Arial, sans-serif'
    text_color = '#333333'

    counts = df_input['Country'].value_counts()

    fig = px.choropleth(
        locations=counts.index.tolist(),
        color=counts.values,
        locationmode="country names",
        color_continuous_scale=SEQUENTIAL_COLORSCALE,
        labels={"color": "Organizations"}
    )

    title_style = {
        'text': f"Geographic Distribution{title_suffix}",
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 14, 'family': font_family, 'color': text_color},
    }
    # Muted land/ocean/borders so the data layer carries the visual weight
    geo_style = dict(
        showframe=False,
        showcoastlines=True,
        coastlinecolor='#B0B0B0',
        coastlinewidth=0.5,
        showland=True,
        landcolor='#F8F8F8',
        showocean=True,
        oceancolor='#FAFAFA',
        showcountries=True,
        countrycolor='#D0D0D0',
        countrywidth=0.3,
        projection_type='natural earth',
    )
    colorbar_style = dict(
        title=dict(text='Count', font={'size': 11}),
        tickfont={'size': 10},
        len=0.6,
        thickness=15,
        outlinewidth=0,
    )

    fig.update_layout(
        title=title_style,
        font={'family': font_family, 'size': 11, 'color': text_color},
        geo=geo_style,
        coloraxis_colorbar=colorbar_style,
        margin={'l': 10, 'r': 10, 't': 50, 'b': 10},
        height=420,
        paper_bgcolor='white',
    )

    return fig, counts


def create_subregion_bar_chart(df_input, title_suffix=""):
    """
    Build a horizontal bar chart of organization counts per subregion.

    When ``df_input`` has a 'Source' column, each bar is stacked into
    "Manual Search + Network" vs "Other Sources" (anything else, including
    missing values). Bars are ordered by ascending total, so the largest
    subregion is drawn at the top.

    Args:
        df_input: DataFrame with a 'Subregion' column and optionally 'Source'
        title_suffix: Optional text appended to the title

    Returns:
        fig: Plotly figure object
        subregion_data: DataFrame of counts — a subregion x source-category
            crosstab when stacked, otherwise a single 'Count' column
    """
    font_family = 'Helvetica Neue, Helvetica, Arial, sans-serif'
    use_stacked = 'Source' in df_input.columns
    fig = go.Figure()

    if use_stacked:
        def _source_bucket(source):
            # Collapse every source other than the manual search into one bucket
            return "Manual Search + Network" if source == "Manual Search + Network" else "Other Sources"

        df_plot = df_input.assign(Source_Category=df_input['Source'].apply(_source_bucket))

        stacked_data = pd.crosstab(df_plot['Subregion'], df_plot['Source_Category'])
        row_order = stacked_data.sum(axis=1).sort_values(ascending=True).index
        stacked_data = stacked_data.reindex(row_order)

        bar_colors = {'Other Sources': COLORS['primary'], 'Manual Search + Network': COLORS['secondary']}
        for category in stacked_data.columns:
            fig.add_trace(go.Bar(
                name=category,
                y=stacked_data.index.tolist(),
                x=stacked_data[category].values,
                orientation='h',
                marker_color=bar_colors.get(category, PALETTE['gray']),
                marker_line_width=0,
            ))

        fig.update_layout(barmode='stack')
        subregion_data = stacked_data
    else:
        counts = df_input['Subregion'].value_counts().sort_values(ascending=True)
        fig.add_trace(go.Bar(
            y=counts.index.tolist(),
            x=counts.values,
            orientation='h',
            marker_color=COLORS['primary'],
            marker_line_width=0,
        ))
        subregion_data = counts.to_frame(name='Count')

    layout_kwargs = dict(
        title={
            'text': f"Organizations by Subregion{title_suffix}",
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 14, 'family': font_family, 'color': '#333333'}
        },
        font={'family': font_family, 'size': 11, 'color': '#333333'},
        xaxis=dict(
            title=dict(text='Count', font={'size': 11}),
            tickfont={'size': 10},
            gridcolor='#E8E8E8',
            gridwidth=0.5,
            zeroline=True,
            zerolinecolor='#D0D0D0',
            zerolinewidth=0.8,
        ),
        yaxis=dict(
            title=dict(text='', font={'size': 11}),
            tickfont={'size': 10},
            automargin=True,
        ),
        # Extra top margin leaves room for the horizontal legend when stacked
        margin={'l': 140, 'r': 30, 't': 70 if use_stacked else 50, 'b': 50},
        height=max(350, len(subregion_data) * 22),
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    # Legend only makes sense when there is more than one trace
    if use_stacked:
        layout_kwargs['legend'] = dict(
            orientation='h',
            yanchor='bottom',
            y=1.02,
            xanchor='center',
            x=0.5,
            font={'size': 10},
        )

    fig.update_layout(**layout_kwargs)

    return fig, subregion_data

In [None]:
# World choropleth for All Organizations, followed by the ten most frequent countries
fig_map_all, country_counts_all = create_country_heatmap(df_all, title_suffix=" — All Organizations")
fig_map_all.show()

print(f"\nTop 10 countries (All Organizations, N={len(df_all)}):")
for country, count in country_counts_all.head(10).items():
    print(f"  {country}: {count}")

In [None]:
# Subregion bar chart for All Organizations
fig_bar_all, subregion_data_all = create_subregion_bar_chart(df_all, title_suffix=" — All Organizations")
fig_bar_all.show()

# create_subregion_bar_chart always returns a DataFrame (one column per source
# category, or a single 'Count' column), so the row sum is each subregion total.
totals = subregion_data_all.sum(axis=1).sort_values(ascending=False)

print("\nSubregion totals (All Organizations):")
for region, count in totals.head(10).items():
    print(f"  {region}: {count}")

In [None]:
# Optional: Save geographic plots as standalone interactive HTML files
if SAVE_OUTPUTS:
    output_dir = DATA_DIR / 'output'
    # Do not rely on an earlier cell having created the directory
    output_dir.mkdir(parents=True, exist_ok=True)
    
    fig_map_all.write_html(output_dir / 'country_map_all.html')
    fig_bar_all.write_html(output_dir / 'subregion_bar_all.html')
    
    print(f"Saved geographic plots to {output_dir}")

---
## Step 7: Methodology Overview

Four-panel visualization summarizing the estimation methodology:
- **Panel A**: Estimation pipeline flowchart
- **Panel B**: Pearson and Spearman correlation matrices showing agreement between the six annotators (3 filters + 3 LLMs)
- **Panel C**: Size-dependent Beta prior distributions for ML prevalence (one per company-size band)
- **Panel D**: Sensitivity vs Specificity trade-off for each estimator

In [None]:
import matplotlib.gridspec as gridspec
import scipy.stats as stats
import seaborn as sns

# --- LOCAL DESIGN CONFIGURATION (for this figure) ---
# Colors used only by the methodology overview figure; comments note which
# estimator / segment each one represents.
METHODOLOGY_COLORS = dict(
    primary='#3C5488',         # Deep Slate Blue (Enterprise)
    accent_red='#DC3220',      # Vermillion
    secondary='#009988',       # Teal (Boutique/GPT)
    highlight='#E68613',       # Amber (Gemini)
    muted_violet='#7B4B94',    # Mid-Scale
    neutral='#868686',         # Gray
    gridline='#E0E0E0',
    text_dark='#333333',
    ci_pure_probit='#2C6E49',  # Forest Green (Probit)
)

# Font sizes (points) for this figure
METHODOLOGY_FONT_SIZES = dict(
    title=16,
    axis_label=11,
    tick=10,
    legend=9,
)

def setup_methodology_axis(ax):
    """Style *ax* per the methodology figure: hide top/right spines, thin dark left/bottom spines, light major grid."""
    for hidden in ('top', 'right'):
        ax.spines[hidden].set_visible(False)

    dark = METHODOLOGY_COLORS['text_dark']
    for visible in ('left', 'bottom'):
        spine = ax.spines[visible]
        spine.set_linewidth(0.8)
        spine.set_color(dark)

    ax.grid(True, which='major', axis='both', color=METHODOLOGY_COLORS['gridline'],
            linestyle='-', linewidth=0.5, alpha=0.5)

# --- DATA PREPARATION ---

# Panel B: Dual Correlation Matrices (Pearson & Spearman)
# Row/column order shared by both matrices below (3 filters, then 3 LLMs).
corr_labels = ['fb', 'fs', 'fbs', 'gemini', 'sonnet', 'gpt']

# Pearson correlation (per-company prevalences); symmetric, unit diagonal
pearson_data = np.array([
    (1.000, 0.334, 0.550, 0.326, 0.231, 0.324),
    (0.334, 1.000, 0.231, 0.495, 0.337, 0.469),
    (0.550, 0.231, 1.000, 0.206, 0.083, 0.165),
    (0.326, 0.495, 0.206, 1.000, 0.818, 0.942),
    (0.231, 0.337, 0.083, 0.818, 1.000, 0.896),
    (0.324, 0.469, 0.165, 0.942, 0.896, 1.000),
])

# Spearman rank correlation (robust to outliers); symmetric, unit diagonal
spearman_data = np.array([
    (1.000, 0.521, 0.628, 0.444, 0.440, 0.486),
    (0.521, 1.000, 0.419, 0.544, 0.450, 0.521),
    (0.628, 0.419, 1.000, 0.308, 0.282, 0.351),
    (0.444, 0.544, 0.308, 1.000, 0.706, 0.853),
    (0.440, 0.450, 0.282, 0.706, 1.000, 0.804),
    (0.486, 0.521, 0.351, 0.853, 0.804, 1.000),
])

# Panel C: Size-dependent Beta Priors, evaluated on prevalence in [0, 0.5]
x_beta = np.linspace(0, 0.5, 300)

# 'mean' is alpha / (alpha + beta), rounded to 3 decimals
size_priors = [
    {'name': label, 'alpha': a, 'beta': b, 'color': METHODOLOGY_COLORS[color_key], 'mean': prior_mean}
    for label, a, b, color_key, prior_mean in (
        ('< 100 employees', 2.788, 23.087, 'secondary', 0.108),
        ('100 - 1K employees', 3.137, 58.208, 'primary', 0.051),
        ('1K - 10K employees', 2.442, 192.854, 'muted_violet', 0.013),
        ('> 10K employees', 1.896, 474.957, 'highlight', 0.004),
    )
]

# Panel D: Estimators (from validation data): specificity, sensitivity, style
methodology_estimators = [
    {'name': label, 'spec': spec, 'sens': sens, 'color': METHODOLOGY_COLORS[color_key], 'marker': mark, 'size': area}
    for label, spec, sens, color_key, mark, area in (
        ('Filter: Broad Yes', 0.910, 0.562, 'neutral', 'o', 70),
        ('Filter: Broad+Strict', 0.981, 0.301, 'neutral', 's', 70),
        ('Filter: Strict No', 0.891, 0.412, 'neutral', 'v', 70),
        ('Gemini 2.5 Flash', 0.829, 0.712, 'highlight', 'P', 100),
        ('GPT-5 Mini', 0.868, 0.725, 'secondary', '^', 100),
        ('Claude Sonnet 4', 0.975, 0.588, 'primary', 'D', 100),
        ('Correlated Probit', 0.926, 0.791, 'ci_pure_probit', '*', 250),
    )
]

# --- PLOTTING ---

# Methodology overview figure; `fig`, `gs` and `ax_pipeline` are module-level
# names that the rest of this cell keeps drawing into.
fig = plt.figure(figsize=(18, 10), dpi=150)

# Grid Configuration:
# 2 Rows, 3 Columns (1/3 each)
# Pipeline takes Col 0 (Rows 0-1) - 1/3 width
# Pearson + Spearman share Cols 1-2 (Row 0) - 2/3 width total
# Prior takes Col 1 (Row 1) - 1/3 width
# Performance takes Col 2 (Row 1) - 1/3 width
gs = gridspec.GridSpec(2, 3, width_ratios=[1, 1, 1], wspace=0.35, hspace=0.3)

# ==========================================
# PANEL A: ESTIMATION PIPELINE (Left Column)
# ==========================================
ax_pipeline = fig.add_subplot(gs[:, 0])  # Spans both rows
# Fixed 0-100 data coordinates so the flowchart boxes below can be placed
# as percentages of the panel width/height; axes themselves are hidden.
ax_pipeline.set_xlim(0, 100)
ax_pipeline.set_ylim(0, 100)
ax_pipeline.axis('off')
ax_pipeline.set_title("A. Estimation Pipeline", loc='left', 
                      fontsize=METHODOLOGY_FONT_SIZES['title'], fontweight='medium')

def draw_box(ax, x, y, w, h, text, color, subtext=""):
    """
    Add a rounded, white-filled flowchart box with a bold title to *ax*.

    (x, y) is the lower-left corner and (w, h) the size, in *ax* data
    coordinates. The title is drawn 3 units above the box centre in *color*;
    the optional *subtext* is drawn 2.5 units below the centre in dark text.

    Returns:
        The FancyBboxPatch that was added to *ax*.
    """
    cx = x + w/2
    cy = y + h/2

    box = FancyBboxPatch(
        (x, y), w, h,
        boxstyle="round,pad=0.5",
        linewidth=1.2,
        edgecolor=color,
        facecolor='white',
        zorder=2,
    )
    ax.add_patch(box)

    # Text sits above the patch (zorder 3 > 2)
    ax.text(cx, cy + 3, text, ha='center', va='center',
            fontsize=11, fontweight='bold', color=color, zorder=3,
            wrap=True)
    if subtext:
        ax.text(cx, cy - 2.5, subtext, ha='center', va='center',
                fontsize=9, color=METHODOLOGY_COLORS['text_dark'], zorder=3,
                wrap=True, linespacing=1.1)
    return box

def draw_arrow(ax, x1, y1, x2, y2):
    """Draw a plain connector arrow from (x1, y1) to (x2, y2) on ``ax``.

    The arrow is meant to sit *beneath* the pipeline boxes, which use zorder 2.
    The zorder must be given to ``annotate`` itself. The arrow patch is drawn
    inside the Annotation's own draw call, so a ``zorder`` placed in
    ``arrowprops`` does not change stacking relative to the other artists. In
    that case the annotation keeps the Text default of 3 and paints over the
    boxes.

    Returns:
        The ``Annotation`` artist, for callers that want to restyle it.
        Existing callers ignore the return value.
    """
    return ax.annotate(
        "",
        xy=(x2, y2),
        xytext=(x1, y1),
        arrowprops=dict(arrowstyle="->", color=METHODOLOGY_COLORS['neutral'], lw=1.5),
        zorder=1,
    )

# Vertical flow layout, in the 0-100 data coordinates of ax_pipeline
center_x = 50                          # horizontal centre of the diagram
step_height = 13                       # height shared by every box
box_width = 96                         # full-width boxes nearly span the panel
box_x = center_x - (box_width / 2)     # left edge of the full-width boxes
gap = 4                                # vertical room reserved for each arrow
start_y = 98                           # top edge of the first box
split_width = 46                       # width of each side-by-side box (steps 2 and 3)

# Bottom edge of each subsequent row, working down the panel
y_split = start_y - step_height - gap - step_height
y_probit = y_split - gap - step_height
y_boot = y_probit - gap - step_height
y_final = y_boot - gap - step_height

# 1. Annotate (full width, top row)
draw_box(ax_pipeline, box_x, start_y - step_height, box_width, step_height,
         "1. Annotate Employees", METHODOLOGY_COLORS['primary'],
         subtext=("6 imperfect classifiers\n"
                  "(3 filters + 3 LLMs) label\n"
                  "each employee as ML or not"))
draw_arrow(ax_pipeline, center_x, start_y - step_height, center_x, y_split + step_height)

# 2 / 3. Side-by-side boxes: parameter estimation (left), synthetic generation (right)
draw_box(ax_pipeline, center_x - split_width - 3, y_split, split_width, step_height,
         "2. Est. Accuracy", METHODOLOGY_COLORS['secondary'],
         subtext=("From ground-truth labels:\n"
                  "sens/spec per annotator\n"
                  "+ error correlations"))
draw_box(ax_pipeline, center_x + 3, y_split, split_width, step_height,
         "3. Synthetic Gen", METHODOLOGY_COLORS['muted_violet'],
         subtext=("For aggregate-only firms:\n"
                  "Gaussian copula generates\n"
                  "employees w/ correlations"))

# Both side boxes feed into the probit step, converging slightly toward the centre
draw_arrow(ax_pipeline, center_x - split_width/2 - 3, y_split, center_x - 5, y_split - gap)
draw_arrow(ax_pipeline, center_x + split_width/2 + 3, y_split, center_x + 5, y_split - gap)

# 4. Probit
draw_box(ax_pipeline, box_x, y_probit, box_width, step_height,
         "4. Compute P(ML)", METHODOLOGY_COLORS['ci_pure_probit'],
         subtext=("Probit combines 6 annotations\n"
                  "w/ correlated errors + size-\n"
                  "dependent prevalence prior"))
draw_arrow(ax_pipeline, center_x, y_probit, center_x, y_probit - gap)

# 5. Bootstrap
draw_box(ax_pipeline, box_x, y_boot, box_width, step_height,
         "5. Bootstrap (1000x)", METHODOLOGY_COLORS['highlight'],
         subtext=("Resample for uncertainty:\n"
                  "matrices, prior, sampling,\n"
                  "correlations, realization"))
draw_arrow(ax_pipeline, center_x, y_boot, center_x, y_boot - gap)

# 6. Output (bottom row, no outgoing arrow)
draw_box(ax_pipeline, box_x, y_final, box_width, step_height,
         "6. Aggregate Counts", METHODOLOGY_COLORS['text_dark'],
         subtext=("Sum P(ML) per company;\n"
                  "bootstrap gives point\n"
                  "estimates + 80% CI"))

# ==========================================
# PANEL B: DUAL CORRELATION MATRICES (Top Right, spans 2/3)
# ==========================================
# Colour range shared by both heatmaps so they are directly comparable.
CORR_VMIN, CORR_VMAX = -0.4, 1.0

# Custom diverging colormap: light blue for negative correlations, near-white at
# exactly 0, then brown -> dark brown for increasingly positive correlations.
# Stops are positioned by value because the colour range is asymmetric. With
# equally spaced stops the neutral colour would land at about +0.07, not 0.
_corr_zero = (0.0 - CORR_VMIN) / (CORR_VMAX - CORR_VMIN)
cmap_corr = LinearSegmentedColormap.from_list(
    'corr',
    [
        (0.0, '#67a9cf'),
        (_corr_zero, '#f7f7f7'),
        ((_corr_zero + 1.0) / 2, '#8c510a'),  # midway between 0 and CORR_VMAX
        (1.0, '#543005'),
    ],
    N=256,
)


def _plot_corr_heatmap(ax, data, title, show_cbar):
    """Draw one annotated correlation heatmap on ``ax`` in the panel-B style.

    Args:
        ax: Target axes.
        data: Matrix passed straight to ``sns.heatmap``. Its rows/columns are
            labelled with the module-level ``corr_labels``.
        title: Centred axes title.
        show_cbar: Whether to attach a colour bar (shrunk to 80% height). Only
            one heatmap needs it because both share CORR_VMIN/CORR_VMAX.
    """
    extra_kwargs = {'cbar_kws': {'shrink': 0.8}} if show_cbar else {}
    sns.heatmap(data, ax=ax, cmap=cmap_corr, annot=True, fmt=".2f",
                xticklabels=corr_labels, yticklabels=corr_labels,
                vmin=CORR_VMIN, vmax=CORR_VMAX, cbar=show_cbar,
                annot_kws={"size": 9}, **extra_kwargs)
    ax.set_title(title, loc='center',
                 fontsize=METHODOLOGY_FONT_SIZES['title'], fontweight='medium')
    ax.tick_params(axis='both', which='both', length=0)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=9)
    plt.setp(ax.get_yticklabels(), rotation=0, fontsize=9)


# Sub-grid holding the two heatmaps side by side within columns 1-2
gs_corr = gs[0, 1:].subgridspec(1, 2, wspace=0.4)

# B1: Pearson Correlation
ax1a = fig.add_subplot(gs_corr[0, 0])
_plot_corr_heatmap(ax1a, pearson_data,
                   "B. Pearson Correlation\n(Per-Company Prevalences)",
                   show_cbar=False)

# B2: Spearman Rank Correlation (carries the shared colour bar)
ax1b = fig.add_subplot(gs_corr[0, 1])
_plot_corr_heatmap(ax1b, spearman_data,
                   "Spearman Rank Correlation\n(Robust to Outliers)",
                   show_cbar=True)

# ==========================================
# PANEL C: SIZE-DEPENDENT BETA PRIORS (Bottom Middle)
# ==========================================
ax2 = fig.add_subplot(gs[1, 1])
setup_methodology_axis(ax2)

# One Beta(alpha, beta) density curve per entry of size_priors, evaluated on
# x_beta, plus a faint dashed vertical marker at that entry's stated mean.
for prior in size_priors:
    shade = prior['color']
    y_beta = stats.beta.pdf(x_beta, prior['alpha'], prior['beta'])
    curve_label = f"{prior['name']} (μ={prior['mean']:.1%})"
    ax2.plot(x_beta, y_beta, color=shade, linewidth=2, label=curve_label)
    ax2.axvline(prior['mean'], color=shade, linestyle='--', linewidth=1, alpha=0.5)

ax2.set_title("C. Size-Dependent Prevalence Priors", loc='left',
              fontsize=METHODOLOGY_FONT_SIZES['title'], fontweight='medium')

axis_label_size = METHODOLOGY_FONT_SIZES['axis_label']
ax2.set_xlabel(r"ML Prevalence ($\pi$)", fontsize=axis_label_size)
ax2.set_ylabel("Density", fontsize=axis_label_size)

# Fixed view window: prevalence 0-25%, density 0-200
ax2.set_xlim(0, 0.25)
ax2.set_ylim(0, 200)

ax2.legend(
    loc='upper right',
    title='Company Size',
    title_fontsize=8,
    fontsize=8,
    frameon=True,
    framealpha=0.95,
    edgecolor=METHODOLOGY_COLORS['gridline'],
)

# ==========================================
# PANEL D: SENSITIVITY vs SPECIFICITY (Bottom Right)
# ==========================================
ax3 = fig.add_subplot(gs[1, 2])
setup_methodology_axis(ax3)

# One marker per estimator. Star markers are drawn without an edge; every other
# marker gets a thin white edge.
for est in methodology_estimators:
    ax3.scatter(est['spec'], est['sens'],
                color=est['color'],
                marker=est['marker'],
                s=est['size'],
                edgecolor='white' if est['marker'] != '*' else 'none',
                linewidth=0.8,
                label=est['name'],
                zorder=10,
                alpha=0.9)

# Styling
ax3.set_title('D. Estimator Performance', loc='left',
              fontsize=METHODOLOGY_FONT_SIZES['title'], fontweight='medium')
ax3.set_xlabel('Specificity (True Negative Rate)', fontsize=METHODOLOGY_FONT_SIZES['axis_label'])
ax3.set_ylabel('Sensitivity (True Positive Rate)', fontsize=METHODOLOGY_FONT_SIZES['axis_label'])
ax3.set_xlim(0.0, 1.02)
ax3.set_ylim(0.0, 1.02)

# Shaded regions: highlight the bands from 0.8 up to the axis limit (1.02)
# Vertical band: high specificity (x >= 0.8, full height)
rect_spec = MplRectangle((0.8, 0.0), 0.22, 1.02, linewidth=0, edgecolor='none',
                         facecolor='#3C5488', alpha=0.12)
ax3.add_patch(rect_spec)
# Horizontal band: high sensitivity (y >= 0.8, full width)
rect_sens = MplRectangle((0.0, 0.8), 1.02, 0.22, linewidth=0, edgecolor='none',
                         facecolor='#2C6E49', alpha=0.12)
ax3.add_patch(rect_sens)

# Custom legend: list the Correlated Probit estimator first. Locate it by its
# label rather than assuming it is the last entry of methodology_estimators,
# so reordering that list cannot silently promote a different estimator.
handles, labels = ax3.get_legend_handles_labels()
if 'Correlated Probit' in labels:
    _probit_idx = labels.index('Correlated Probit')
    handles.insert(0, handles.pop(_probit_idx))
    labels.insert(0, labels.pop(_probit_idx))
ax3.legend(handles, labels, loc='lower left', frameon=True,
           fontsize=METHODOLOGY_FONT_SIZES['legend'], framealpha=0.95,
           edgecolor=METHODOLOGY_COLORS['gridline'])

plt.tight_layout()

# Save to output directory. The directory is created on demand, so a fresh
# checkout without DATA_DIR/output does not make savefig raise.
# NOTE(review): assumes DATA_DIR is a pathlib.Path (it is joined with `/`).
if SAVE_OUTPUTS:
    output_path = DATA_DIR / 'output' / 'methodology_overview.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=200, bbox_inches='tight', facecolor='white')
    print(f"Saved: {output_path}")

plt.show()