# Q4: Geographic Distribution & Regional Specialization

## Research Questions

1. **Distribution:** Where are clinical trials registered, and how concentrated is the market?
2. **Specialization:** Do countries show relative over-/under-representation in therapeutic areas?
3. **Trends:** Has the geographic distribution shifted over time?

## Structure

| Section | Question | Approach |
|---------|----------|----------|
| 2. Distribution | Market concentration | HHI with bootstrap CI |
| 3. Specialization | Regional patterns | Location Quotient (descriptive, no CI) |
| 4. Temporal | Distribution shifts | JSD, share changes |
| Appendix A | Site counts | Descriptive |
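
The temporal section compares country distributions with Jensen-Shannon divergence (JSD). The sketch below is one possible pure-Python version, not necessarily the implementation used later in this notebook. The function name and the two toy distributions are hypothetical. It uses base-2 logarithms, so the result is bounded between 0 (identical distributions) and 1 (no shared countries).

```python
import math

def jensen_shannon_divergence(p_counts, q_counts):
    """JSD (base 2) between two {country: count} dicts; value lies in [0, 1]."""
    keys = set(p_counts) | set(q_counts)
    p_total = sum(p_counts.values())
    q_total = sum(q_counts.values())
    p = {k: p_counts.get(k, 0) / p_total for k in keys}
    q = {k: q_counts.get(k, 0) / q_total for k in keys}
    # Mixture distribution M = (P + Q) / 2
    m = {k: 0.5 * (p[k] + q[k]) for k in keys}

    def kl_to_mixture(dist):
        # KL(dist || M); zero-probability terms contribute nothing
        return sum(dist[k] * math.log2(dist[k] / m[k]) for k in keys if dist[k] > 0)

    return 0.5 * kl_to_mixture(p) + 0.5 * kl_to_mixture(q)

# Hypothetical trial counts by country in two periods
early = {"United States": 70, "Germany": 20, "China": 10}
late = {"United States": 50, "Germany": 15, "China": 35}

jsd_same = jensen_shannon_divergence(early, early)
jsd_shift = jensen_shannon_divergence(early, late)
jsd_disjoint = jensen_shannon_divergence({"A": 1}, {"B": 1})
```

Because JSD is symmetric and bounded, values from different period pairs can be compared directly, which a raw KL divergence would not allow.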

## Critical Caveats

**This analysis is exploratory/descriptive, not inferential.**

1. **Condition labels are not normalized.** The registry uses free-text condition names; "breast cancer", "breast neoplasm", and "carcinoma of breast" appear as separate conditions. Without deduplication, LQ values measure registration patterns in fragmented labels, not true specialization.

2. **Multi-condition trials inflate LQ denominators.** A trial with 3 conditions contributes 3× to condition totals. High LQ for conditions that frequently co-occur may be artifacts.

3. **Multinational trials are assigned to a single country.** About 8% of trials with location data (7,196 of 88,404; see §1.2) have sites in multiple countries but are assigned to one "primary" country (mode of sites). This inflates concentration for top markets.

4. **Trial counts ≠ enrollment capacity.** A high LQ for India in oncology means India registers many oncology trials relative to its portfolio—not that India has proportional enrollment capacity. Site selection requires enrollment rate data.

5. **Temporal analysis reflects registry coverage, not only trial activity.** FDAAA 2007 mandated US registration; pre-2007 data is US-biased. Non-US registries (EUCTR, ChiCTR) grew over time.
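
Caveat 2 can be made concrete with a toy example. The three trials below are hypothetical; the point is only that summing condition totals counts a multi-condition trial once per listed condition, so the totals used for LQ exceed the number of trials.

```python
from collections import Counter

# Hypothetical trials, each mapped to its registered condition labels
trials = {
    "T1": ["breast cancer"],
    "T2": ["breast cancer", "metastasis", "pain"],
    "T3": ["pain"],
}

# One count per (trial, condition) pair, as in a trial-by-condition join
condition_totals = Counter(cond for conds in trials.values() for cond in conds)

n_trials = len(trials)
n_condition_rows = sum(condition_totals.values())
inflation_factor = n_condition_rows / n_trials  # > 1 whenever any trial lists several conditions
```

Here trial T2 alone accounts for three of the five condition rows.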

## Scope

- **Data:** ClinicalTrials.gov, start year 1990–2025
- **Primary country:** Mode of site locations; ties broken arbitrarily (first in SQL sort)
- **LQ interpretation:** Over-representation relative to global average, not validated expertise
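
The primary-country rule can be sketched as follows. This is an illustration rather than the SQL used in `q4_abt.sql`: the function name is hypothetical, and it assumes ties are resolved alphabetically, which is one way to make the "first in SQL sort" behaviour reproducible.

```python
from collections import Counter

def primary_country(site_countries):
    """Most frequent country among a trial's sites; ties resolved alphabetically."""
    counts = Counter(site_countries)
    # Key orders by descending site count, then by name, so the result is deterministic
    return min(counts, key=lambda country: (-counts[country], country))

clear_mode = primary_country(["Germany", "France", "Germany"])
tied = primary_country(["Spain", "France"])
```

In the tied case the trial is credited entirely to one of two equally represented countries, which is why the tie-break rule matters for small multinational trials.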

In [1]:
# ============================================================
# Setup
# ============================================================

import sys
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, mannwhitneyu, spearmanr
from IPython.display import display, Markdown
import plotly.graph_objects as go

# Project root for imports
PROJECT_ROOT = Path('..')
sys.path.insert(0, str(PROJECT_ROOT))

# Shared utilities
from src.data.loader import load_sql_query, get_db_connection
from src.analysis.viz import DEFAULT_COLORS, create_horizontal_bar_chart
from src.analysis.metrics import calc_cramers_v, interpret_effect_size
from src.analysis.constants import PHASE_ORDER_CLINICAL, COHORT_BINS, COHORT_LABELS

# Paths (validated at setup)
DB_PATH = PROJECT_ROOT / 'data' / 'database' / 'clinical_trials.db'
SQL_PATH = PROJECT_ROOT / 'sql' / 'queries'
assert DB_PATH.exists(), f"DB not found: {DB_PATH}"
assert SQL_PATH.exists(), f"SQL folder not found: {SQL_PATH}"

# Display settings
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

In [2]:
# ============================================================
# Database connection
# ============================================================

conn = get_db_connection(DB_PATH)

---

## 1. Data Loading & Validation

In [None]:
# ============================================================
# 1.1 Load ABT (study-level with geographic features)
# ============================================================

df_abt = load_sql_query('q4_abt.sql', conn, SQL_PATH)

n_studies = len(df_abt)
n_with_location = df_abt['has_location_data'].sum()
pct_location = n_with_location / n_studies * 100
year_min, year_max = df_abt['start_year'].min(), df_abt['start_year'].max()

display(Markdown(f"""
**ABT:** {n_studies:,} studies ({year_min}–{year_max}). Location data: {n_with_location:,} ({pct_location:.0f}%).
"""))

# Filter to studies with location data
df_geo = df_abt[df_abt['has_location_data'] == 1].copy()
n_geo = len(df_geo)

In [4]:
# ============================================================
# 1.2 Geographic summary statistics
# ============================================================

# Unique countries
n_countries = df_geo['primary_country'].nunique()

# Site complexity distribution
n_single_site = df_geo['is_single_site'].sum()
n_multinational = df_geo['is_multinational'].sum()
n_large_multisite = df_geo['is_large_multisite'].sum()

pct_single = n_single_site / n_geo * 100
pct_multinational = n_multinational / n_geo * 100
pct_large = n_large_multisite / n_geo * 100

# Site count statistics
median_sites = df_geo['n_sites'].median()
mean_sites = df_geo['n_sites'].mean()
max_sites = df_geo['n_sites'].max()
q75_sites = df_geo['n_sites'].quantile(0.75)

display(Markdown(f"""
**Geographic summary (n = {n_geo:,} studies with location data):**

| Metric | Value |
|--------|-------|
| Unique countries | {n_countries:,} |
| Single-site trials | {n_single_site:,} ({pct_single:.1f}%) |
| Multinational trials | {n_multinational:,} ({pct_multinational:.1f}%) |
| Large multi-site (≥10 sites) | {n_large_multisite:,} ({pct_large:.1f}%) |
| Median sites per trial | {median_sites:.0f} |
"""))


**Geographic summary (n = 88,404 studies with location data):**

| Metric | Value |
|--------|-------|
| Unique countries | 175 |
| Single-site trials | 64,557 (73.0%) |
| Multinational trials | 7,196 (8.1%) |
| Large multi-site (≥10 sites) | 8,959 (10.1%) |
| Median sites per trial | 1 |


---

## 2. Geographic Distribution

**Question:** Where are clinical trials conducted?
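
Concentration in this section is summarized with the Herfindahl-Hirschman Index (HHI): the sum of squared market shares, scaled to 0–10,000. A minimal sketch on hypothetical counts, using the same classification cut-offs as the cells below (helper names are illustrative, not part of `src`):

```python
def hhi_from_counts(trial_counts):
    """HHI on the 0-10,000 scale from a {country: n_trials} dict."""
    total = sum(trial_counts.values())
    return sum((n / total) ** 2 for n in trial_counts.values()) * 10000

def classify_hhi(hhi):
    """Cut-offs used in this notebook: below 1500, 1500 to below 2500, 2500 and above."""
    if hhi < 1500:
        return "unconcentrated"
    if hhi < 2500:
        return "moderately concentrated"
    return "highly concentrated"

# Hypothetical three-country market: shares 0.5, 0.3, 0.2
toy_hhi = hhi_from_counts({"A": 50, "B": 30, "C": 20})   # (0.25 + 0.09 + 0.04) * 10000
# Ten countries with equal shares give the minimum HHI for ten participants
even_hhi = hhi_from_counts({name: 10 for name in "ABCDEFGHIJ"})
```

With equal shares among n countries the index equals 10,000 / n, so the value is sensitive to a few large countries far more than to the long tail.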

In [None]:
# ============================================================
# 2.1 Country-level distribution
# ============================================================

# Aggregate by primary country
df_country = (
    df_geo
    .groupby('primary_country')
    .agg(
        n_trials=('study_id', 'nunique'),
        n_interventional=('is_interventional', 'sum'),
        n_industry=('is_industry_sponsor', 'sum'),
        median_sites=('n_sites', 'median'),
        pct_multinational=('is_multinational', 'mean'),
    )
    .reset_index()
    .sort_values('n_trials', ascending=False)
)
df_country['pct_multinational'] = df_country['pct_multinational'] * 100
df_country['pct_interventional'] = df_country['n_interventional'] / df_country['n_trials'] * 100
df_country['pct_industry'] = df_country['n_industry'] / df_country['n_trials'] * 100

# Top 20 countries
top20 = df_country.head(20).copy()

# Concentration metrics (trial-weighted, primary country assignment)
total_trials = df_country['n_trials'].sum()
top5_share = df_country.head(5)['n_trials'].sum() / total_trials * 100
top10_share = df_country.head(10)['n_trials'].sum() / total_trials * 100

# HHI (Herfindahl-Hirschman Index) with bootstrap CI
df_country['market_share'] = df_country['n_trials'] / total_trials
hhi = (df_country['market_share'] ** 2).sum() * 10000

# Bootstrap confidence interval for HHI
np.random.seed(42)
n_bootstrap = 1000
country_array = df_geo['primary_country'].values
hhi_bootstrap = []

for _ in range(n_bootstrap):
    sample_idx = np.random.choice(len(country_array), size=len(country_array), replace=True)
    sample_countries = country_array[sample_idx]
    counts = pd.Series(sample_countries).value_counts()
    shares = counts / counts.sum()
    hhi_boot = (shares ** 2).sum() * 10000
    hhi_bootstrap.append(hhi_boot)

hhi_ci_low = np.percentile(hhi_bootstrap, 2.5)
hhi_ci_high = np.percentile(hhi_bootstrap, 97.5)

# HHI interpretation (2010 DOJ/FTC Horizontal Merger Guidelines thresholds:
# <1500 unconcentrated, 1500-2500 moderate, >2500 high)
if hhi < 1500:
    hhi_interpretation = "unconcentrated"
elif hhi < 2500:
    hhi_interpretation = "moderately concentrated"
else:
    hhi_interpretation = "highly concentrated"

ci_spans_boundary = (hhi_ci_low < 1500 < hhi_ci_high) or (hhi_ci_low < 2500 < hhi_ci_high)
boundary_note = " (CI spans threshold)" if ci_spans_boundary else ""

pct_multinational_all = df_geo['is_multinational'].mean() * 100

display(Markdown(f"""
### 2.1 Geographic concentration (primary country assignment)

| Metric | Value | 95% CI |
|--------|-------|--------|
| Top 5 countries | {top5_share:.0f}% | — |
| Top 10 countries | {top10_share:.0f}% | — |
| **HHI** | **{hhi:.0f}** ({hhi_interpretation}) | [{hhi_ci_low:.0f}, {hhi_ci_high:.0f}]{boundary_note} |

**Interpretation caveat:** {pct_multinational_all:.0f}% of trials are multinational but assigned to a single country (mode of sites). This inflates top-market shares and HHI. See §2.1.2 for site-weighted alternative.
"""))

display(Markdown("**Top 20 countries by trial count:**"))
display(
    top20[['primary_country', 'n_trials', 'pct_interventional', 'pct_industry', 'median_sites']]
    .rename(columns={
        'primary_country': 'Country',
        'n_trials': 'Trials',
        'pct_interventional': 'Interventional %',
        'pct_industry': 'Industry %',
        'median_sites': 'Median Sites',
    })
    .style.format({
        'Trials': '{:,.0f}',
        'Interventional %': '{:.1f}%',
        'Industry %': '{:.1f}%',
        'Median Sites': '{:.0f}',
    }).hide(axis='index')
)

In [6]:
# ============================================================
# 2.1.1 Sensitivity: Single-country trials only
# ============================================================

df_single_country = df_geo[df_geo['is_multinational'] == 0].copy()
n_single = len(df_single_country)

single_country_counts = df_single_country['primary_country'].value_counts()
single_total = single_country_counts.sum()
single_shares = single_country_counts / single_total
hhi_single = (single_shares ** 2).sum() * 10000
top5_single_share = single_country_counts.head(5).sum() / single_total * 100

if hhi_single < 1500:
    hhi_single_interp = "unconcentrated"
elif hhi_single < 2500:
    hhi_single_interp = "moderately concentrated"
else:
    hhi_single_interp = "highly concentrated"

hhi_delta = hhi - hhi_single

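# Bootstrap CI for delta
# Note: the two samples are resampled independently even though single-country
# trials are a subset of all trials, so this interval is approximate.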
np.random.seed(43)
single_country_array = df_single_country['primary_country'].values
hhi_delta_bootstrap = []

for _ in range(n_bootstrap):
    sample_all = country_array[np.random.choice(len(country_array), len(country_array), replace=True)]
    shares_all = pd.Series(sample_all).value_counts() / len(sample_all)
    hhi_all_boot = (shares_all ** 2).sum() * 10000
    
    sample_single = single_country_array[np.random.choice(len(single_country_array), len(single_country_array), replace=True)]
    shares_single = pd.Series(sample_single).value_counts() / len(sample_single)
    hhi_single_boot = (shares_single ** 2).sum() * 10000
    
    hhi_delta_bootstrap.append(hhi_all_boot - hhi_single_boot)

hhi_delta_ci_low = np.percentile(hhi_delta_bootstrap, 2.5)
hhi_delta_ci_high = np.percentile(hhi_delta_bootstrap, 97.5)
delta_significant = not (hhi_delta_ci_low <= 0 <= hhi_delta_ci_high)

display(Markdown(f"""
### 2.1.1 Sensitivity: Single-country trials only

| Metric | All | Single-country | Δ [95% CI] |
|--------|-----|----------------|------------|
| N | {n_geo:,} | {n_single:,} | — |
| HHI | {hhi:.0f} | {hhi_single:.0f} | {hhi_delta:+.0f} [{hhi_delta_ci_low:+.0f}, {hhi_delta_ci_high:+.0f}] |

ΔHHI {'excludes' if delta_significant else 'includes'} zero. Classification {'changes' if hhi_interpretation != hhi_single_interp else 'unchanged'} ({hhi_interpretation} → {hhi_single_interp}).
"""))


### 2.1.1 Sensitivity: Single-country trials only

| Metric | All | Single-country | Δ [95% CI] |
|--------|-----|----------------|------------|
| N | 88,404 | 81,208 | — |
| HHI | 1504 | 1423 | +82 [+52, +108] |

ΔHHI excludes zero. Classification changes (moderately concentrated → unconcentrated).
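
Both columns in the table above still credit each trial to exactly one country. An alternative is fractional site weighting, where every site contributes 1/n_sites of its trial to the site's country. The sketch below uses a hypothetical site list to show how the weights are formed; names are illustrative.

```python
from collections import Counter, defaultdict

# Hypothetical site list: one (study_id, country) row per site
sites = [
    ("T1", "United States"),
    ("T1", "United States"),
    ("T1", "United States"),
    ("T1", "Canada"),
    ("T2", "Germany"),
]

# Number of sites per study determines each site's fractional weight
sites_per_study = Counter(study for study, _ in sites)

weighted_trials = defaultdict(float)
for study, country in sites:
    weighted_trials[country] += 1.0 / sites_per_study[study]

total_weight = sum(weighted_trials.values())  # equals the number of studies
```

Under primary-country assignment T1 would count fully for the United States. With fractional weights it splits 0.75 / 0.25, while the total across countries still equals the number of trials.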


In [None]:
# ============================================================
# 2.1.2 Site-weighted concentration (fractional assignment)
# ============================================================

# Each site contributes 1/n_sites to its country
query_site_weighted = """
WITH site_contributions AS (
    SELECT 
        l.study_id,
        l.country,
        1.0 / COUNT(*) OVER (PARTITION BY l.study_id) AS fractional_weight
    FROM locations l
    JOIN v_studies_clean s ON l.study_id = s.study_id
    WHERE l.country IS NOT NULL AND l.country != ''
      AND s.is_start_year_in_scope = 1
)
SELECT 
    country,
    SUM(fractional_weight) AS weighted_trials,
    COUNT(DISTINCT study_id) AS n_trials_any_site
FROM site_contributions
GROUP BY country
ORDER BY weighted_trials DESC
"""

df_site_weighted = pd.read_sql(query_site_weighted, conn)

# Calculate site-weighted HHI
total_weighted = df_site_weighted['weighted_trials'].sum()
df_site_weighted['market_share'] = df_site_weighted['weighted_trials'] / total_weighted
hhi_site_weighted = (df_site_weighted['market_share'] ** 2).sum() * 10000

top5_site_share = df_site_weighted.head(5)['weighted_trials'].sum() / total_weighted * 100
top10_site_share = df_site_weighted.head(10)['weighted_trials'].sum() / total_weighted * 100

# Bootstrap CI for site-weighted HHI
np.random.seed(45)
hhi_site_bootstrap = []

# Get site-level data for bootstrap
df_sites_boot = pd.read_sql("""
    SELECT l.study_id, l.country
    FROM locations l
    JOIN v_studies_clean s ON l.study_id = s.study_id
    WHERE l.country IS NOT NULL AND l.country != ''
      AND s.is_start_year_in_scope = 1
""", conn)

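# Per-site fractional weight (1 / number of sites in the study), summed to one
# row per study × country so that each study's rows total 1.0
site_w = 1.0 / df_sites_boot.groupby('study_id')['country'].transform('size')
study_country_w = (
    df_sites_boot.assign(weight=site_w)
    .groupby(['study_id', 'country'], as_index=False)['weight'].sum()
)
study_codes, unique_studies = pd.factorize(study_country_w['study_id'])
country_codes, unique_countries = pd.factorize(study_country_w['country'])
n_boot_studies = len(unique_studies)

for _ in range(n_bootstrap):
    # Sample studies with replacement; a study drawn k times must count k times.
    # (Filtering with .isin() would silently drop the repeated draws.)
    draws = np.random.choice(n_boot_studies, size=n_boot_studies, replace=True)
    multiplicity = np.bincount(draws, minlength=n_boot_studies)
    
    # Scale each study × country weight by how often the study was drawn
    boot_w = study_country_w['weight'].values * multiplicity[study_codes]
    
    # Aggregate by country
    country_weighted = np.bincount(country_codes, weights=boot_w, minlength=len(unique_countries))
    shares = country_weighted / country_weighted.sum()
    hhi_boot = (shares ** 2).sum() * 10000
    hhi_site_bootstrap.append(hhi_boot)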

hhi_site_ci_low = np.percentile(hhi_site_bootstrap, 2.5)
hhi_site_ci_high = np.percentile(hhi_site_bootstrap, 97.5)

if hhi_site_weighted < 1500:
    hhi_site_interp = "unconcentrated"
elif hhi_site_weighted < 2500:
    hhi_site_interp = "moderately concentrated"
else:
    hhi_site_interp = "highly concentrated"

hhi_delta_method = hhi - hhi_site_weighted
top5_delta = top5_share - top5_site_share
classification_change = hhi_interpretation != hhi_site_interp

display(Markdown(f"""
### 2.1.2 Site-weighted concentration

Each site contributes 1/n_sites of its trial to the site's country, so a trial with 3 sites in one country and 1 site in another splits 3/4 and 1/4, rather than assigning the full trial to one country.

| Metric | Primary Country | Site-Weighted | Δ |
|--------|-----------------|---------------|---|
| Top 5 | {top5_share:.0f}% | {top5_site_share:.0f}% | {top5_delta:+.0f}pp |
| Top 10 | {top10_share:.0f}% | {top10_site_share:.0f}% | {top10_share - top10_site_share:+.0f}pp |
| **HHI** | **{hhi:.0f}** [{hhi_ci_low:.0f}, {hhi_ci_high:.0f}] | **{hhi_site_weighted:.0f}** [{hhi_site_ci_low:.0f}, {hhi_site_ci_high:.0f}] | **{hhi_delta_method:+.0f}** |
| Classification | {hhi_interpretation} | {hhi_site_interp} | {'differs' if classification_change else 'consistent'} |

**Limitation:** Site-weighted counts participation, not capacity. Neither method captures enrollment volume.
"""))

In [7]:
# ============================================================
# 2.2 Country distribution chart
# ============================================================

# Prepare data for plot (sorted ascending for horizontal bar - largest at top)
plot_data = top20.sort_values('n_trials', ascending=True).tail(15)

# Key values for subtitle
top1 = df_country.iloc[0]
top2 = df_country.iloc[1]

fig_country = create_horizontal_bar_chart(
    data=plot_data,
    value_col='n_trials',
    label_col='primary_country',
    title='Top 15 Countries by Trial Count',
    subtitle=f"{top1['primary_country']}: {top1['n_trials']:,.0f} | {top2['primary_country']}: {top2['n_trials']:,.0f}",
    show_pct=False,
    height=500,
)
fig_country.show()

---

## 3. Regional Specialization

**Question:** Do countries show relative specialization in particular therapeutic areas?

**Approach:** Use Location Quotient (LQ) to measure over-/under-representation of conditions by country.

$$LQ = \frac{\text{share of country's trials in condition}}{\text{share of global trials in condition}}$$

- **LQ > 1.5**: Country relatively specialized (over-represented)
- **LQ ≈ 1**: Country matches global distribution
- **LQ < 1**: Country under-represented relative to the global distribution
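
A worked example of the formula, using hypothetical counts (the function and argument names are illustrative and not taken from `q4_country_condition.sql`):

```python
def location_quotient(n_country_condition, n_country_total, n_condition_total, n_global_total):
    """LQ = (country's share of its trials in a condition) / (global share of trials in that condition)."""
    country_share = n_country_condition / n_country_total
    global_share = n_condition_total / n_global_total
    return country_share / global_share

# Hypothetical: 120 of a country's 1,000 trials target a condition that
# accounts for 4,000 of 100,000 trials globally -> 12% vs 4%, LQ of about 3
lq_over = location_quotient(120, 1000, 4000, 100000)

# Hypothetical: 20 of 1,000 trials in the same condition -> 2% vs 4%, LQ of about 0.5
lq_under = location_quotient(20, 1000, 4000, 100000)
```

Because LQ is a ratio of shares, a small country with a handful of trials in a rare condition can produce a very large value, which is why the minimum-count filters applied in §3.1 matter.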

In [None]:
# ============================================================
# 3.1 Load country × condition data for specialization analysis
# ============================================================

df_spec = load_sql_query('q4_country_condition.sql', conn, SQL_PATH)

n_combinations = len(df_spec)
n_countries_spec = df_spec['country'].nunique()
n_conditions_spec = df_spec['condition_standardized'].nunique()

# Quantify multi-condition trials
median_conds_per_trial = df_geo['n_conditions'].median()
mean_conds_per_trial = df_geo['n_conditions'].mean()
pct_multi_condition = (df_geo['n_conditions'] > 1).mean() * 100

# Quantify condition fragmentation
top_conditions = df_spec.groupby('condition_standardized')['n_trials'].sum().nlargest(50)
top50_trials = top_conditions.sum()
total_condition_trials = df_spec.groupby('condition_standardized')['n_trials'].sum().sum()
top50_coverage = top50_trials / total_condition_trials * 100

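# Count prefixes over unique condition labels: df_spec has one row per
# country × condition, so counting rows would flag any condition that simply
# appears in more than one country
condition_prefixes = df_spec['condition_standardized'].drop_duplicates().str[:10].value_counts()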
n_potential_synonyms = (condition_prefixes > 1).sum()
pct_fragmented = n_potential_synonyms / len(condition_prefixes) * 100

display(Markdown(f"""
### 3.1 Specialization data

{n_combinations:,} country × condition pairs from {n_countries_spec} countries and {n_conditions_spec:,} conditions.

**Filters:** Countries ≥50 trials, conditions ≥100 global trials, pairs ≥5 trials.

| Characteristic | Value |
|----------------|-------|
| Multi-condition trials | {pct_multi_condition:.0f}% |
| Top-50 conditions coverage | {top50_coverage:.0f}% |
| Prefix fragmentation (synonyms) | ~{pct_fragmented:.0f}% |

*Multi-condition trials inflate LQ denominators. LQ does not control for phase or sponsor. See §3.2.1–3.2.3 for sensitivity analyses.*
"""))

In [None]:
# ============================================================
# 3.2 Identify apparent specializations (LQ > 1.5)
# ============================================================

df_high_lq = df_spec[df_spec['location_quotient'] > 1.5].copy()
df_high_lq = df_high_lq.sort_values('location_quotient', ascending=False)

major_countries = ['United States', 'China', 'Germany', 'France', 'United Kingdom', 'Japan', 'Canada', 'Italy', 'Spain', 'India']

specializations = []
for country in major_countries:
    country_data = df_high_lq[df_high_lq['country'] == country].head(3)
    if len(country_data) > 0:
        specs = []
        for _, row in country_data.iterrows():
            lq = row['location_quotient']
            n = row['n_trials']
            specs.append(f"{row['condition_standardized'][:30]} (LQ={lq:.1f}, n={n:,.0f})")
        specializations.append({
            'Country': country,
            'Apparent Specializations (LQ > 1.5)': ', '.join(specs),
        })

df_spec_summary = pd.DataFrame(specializations)

display(Markdown("### 3.2 Apparent specializations by country"))
display(Markdown("*LQ > 1.5 indicates over-representation, not expertise. See validations below.*"))
display(df_spec_summary.style.hide(axis='index'))

# ============================================================
# 3.2.1 Single-condition sensitivity: impact of multi-condition trials
# ============================================================

display(Markdown("### 3.2.1 Multi-condition inflation sensitivity"))

# Quantify multi-condition impact
multi_cond_trials = df_geo[df_geo['n_conditions'] > 1]
single_cond_trials = df_geo[df_geo['n_conditions'] == 1]
n_multi = len(multi_cond_trials)
n_single = len(single_cond_trials)
pct_multi = n_multi / len(df_geo) * 100

# Query LQ using only single-condition trials (no denominator inflation)
query_single_cond_lq = """
WITH 
single_cond_studies AS (
    SELECT c.study_id, MIN(LOWER(TRIM(c.condition_name))) AS condition_standardized
    FROM conditions c
    JOIN v_studies_clean s ON c.study_id = s.study_id
    WHERE c.condition_name IS NOT NULL AND TRIM(c.condition_name) != ''
      AND s.is_start_year_in_scope = 1
    GROUP BY c.study_id
    HAVING COUNT(DISTINCT LOWER(TRIM(c.condition_name))) = 1
),
single_study_countries AS (
    SELECT DISTINCT l.study_id, l.country, sc.condition_standardized
    FROM locations l
    JOIN single_cond_studies sc ON l.study_id = sc.study_id
    JOIN v_studies_clean s ON l.study_id = s.study_id
    WHERE l.country IS NOT NULL AND l.country != ''
      AND s.is_start_year_in_scope = 1
),
single_country_condition AS (
    SELECT country, condition_standardized, COUNT(DISTINCT study_id) AS n_trials
    FROM single_study_countries GROUP BY country, condition_standardized
),
single_country_totals AS (
    SELECT country, COUNT(DISTINCT study_id) AS country_total
    FROM single_study_countries GROUP BY country
),
single_condition_totals AS (
    SELECT condition_standardized, COUNT(DISTINCT study_id) AS condition_total
    FROM single_study_countries GROUP BY condition_standardized
),
single_global AS (
    SELECT COUNT(DISTINCT study_id) AS global_total FROM single_study_countries
)
SELECT 
    scc.country, scc.condition_standardized, scc.n_trials,
    sct.country_total, scond.condition_total, sg.global_total,
    ROUND((CAST(scc.n_trials AS REAL) / sct.country_total) / 
          (CAST(scond.condition_total AS REAL) / sg.global_total), 3) AS lq_single
FROM single_country_condition scc
JOIN single_country_totals sct ON scc.country = sct.country
JOIN single_condition_totals scond ON scc.condition_standardized = scond.condition_standardized
CROSS JOIN single_global sg
WHERE sct.country_total >= 30 AND scond.condition_total >= 50 AND scc.n_trials >= 3
"""

df_single_lq = pd.read_sql(query_single_cond_lq, conn)

# Compare top specializations: all-trial LQ vs single-condition LQ
single_validation = []
for country in major_countries[:5]:
    country_top_full = df_high_lq[df_high_lq['country'] == country].head(1)
    if len(country_top_full) == 0:
        continue
    
    top_cond = country_top_full.iloc[0]['condition_standardized']
    full_lq = country_top_full.iloc[0]['location_quotient']
    full_n = country_top_full.iloc[0]['n_trials']
    
    single_match = df_single_lq[
        (df_single_lq['country'] == country) & 
        (df_single_lq['condition_standardized'] == top_cond)
    ]
    
    if len(single_match) > 0:
        single_lq = single_match.iloc[0]['lq_single']
        single_n = single_match.iloc[0]['n_trials']
        lq_delta = single_lq - full_lq
        lq_delta_pct = (lq_delta / full_lq * 100) if full_lq > 0 else 0
        
        single_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'All LQ': f"{full_lq:.2f}",
            'Single-cond LQ': f"{single_lq:.2f}",
            'Δ': f"{lq_delta:+.2f} ({lq_delta_pct:+.0f}%)",
        })
    else:
        single_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'All LQ': f"{full_lq:.2f}",
            'Single-cond LQ': '—',
            'Δ': 'insufficient',
        })

if single_validation:
    df_single_val = pd.DataFrame(single_validation)
    display(df_single_val.style.hide(axis='index'))

display(Markdown(f"""
Multi-condition trials = {pct_multi:.0f}% of data. Single-condition LQ excludes these, removing denominator inflation. A large Δ (|Δ| > 20%) suggests the specialization may be driven by co-occurring conditions rather than primary focus.
"""))

# ============================================================
# 3.2.2 Phase-stratified LQ sensitivity
# ============================================================

display(Markdown("### 3.2.2 Phase-stratified LQ (Phase 3 only)"))
display(Markdown("*Checks whether specialization persists when the comparison is restricted to Phase 3 trials.*"))

query_phase3_lq = """
WITH 
p3_study_countries AS (
    SELECT DISTINCT l.study_id, l.country
    FROM locations l
    JOIN v_studies_clean s ON l.study_id = s.study_id
    WHERE l.country IS NOT NULL AND l.country != ''
      AND s.is_start_year_in_scope = 1
      AND s.phase = 'PHASE3'
),
p3_study_conditions AS (
    SELECT DISTINCT c.study_id, LOWER(TRIM(c.condition_name)) AS condition_standardized
    FROM conditions c
    JOIN v_studies_clean s ON c.study_id = s.study_id
    WHERE c.condition_name IS NOT NULL AND TRIM(c.condition_name) != ''
      AND s.is_start_year_in_scope = 1
      AND s.phase = 'PHASE3'
),
p3_country_condition AS (
    SELECT sc.country, scond.condition_standardized, COUNT(DISTINCT sc.study_id) AS n_trials
    FROM p3_study_countries sc
    JOIN p3_study_conditions scond ON sc.study_id = scond.study_id
    GROUP BY sc.country, scond.condition_standardized
),
p3_country_totals AS (
    SELECT country, COUNT(DISTINCT study_id) AS country_total
    FROM p3_study_countries GROUP BY country
),
p3_condition_totals AS (
    SELECT condition_standardized, COUNT(DISTINCT study_id) AS condition_total
    FROM p3_study_conditions GROUP BY condition_standardized
),
p3_global AS (
    SELECT COUNT(DISTINCT sc.study_id) AS global_total
    FROM p3_study_countries sc
    JOIN p3_study_conditions scond ON sc.study_id = scond.study_id
)
SELECT 
    pcc.country, pcc.condition_standardized, pcc.n_trials,
    pct.country_total, pcond.condition_total, pg.global_total,
    ROUND((CAST(pcc.n_trials AS REAL) / pct.country_total) / 
          (CAST(pcond.condition_total AS REAL) / pg.global_total), 3) AS lq_p3
FROM p3_country_condition pcc
JOIN p3_country_totals pct ON pcc.country = pct.country
JOIN p3_condition_totals pcond ON pcc.condition_standardized = pcond.condition_standardized
CROSS JOIN p3_global pg
WHERE pct.country_total >= 20 AND pcond.condition_total >= 30 AND pcc.n_trials >= 3
"""

df_p3_spec = pd.read_sql(query_phase3_lq, conn)

phase_validation = []
for country in major_countries[:5]:
    country_top_full = df_high_lq[df_high_lq['country'] == country].head(1)
    if len(country_top_full) == 0:
        continue
    
    top_cond = country_top_full.iloc[0]['condition_standardized']
    full_lq = country_top_full.iloc[0]['location_quotient']
    
    p3_match = df_p3_spec[
        (df_p3_spec['country'] == country) & 
        (df_p3_spec['condition_standardized'] == top_cond)
    ]
    
    if len(p3_match) > 0:
        p3_lq = p3_match.iloc[0]['lq_p3']
        p3_n = p3_match.iloc[0]['n_trials']
        lq_delta = p3_lq - full_lq
        
        phase_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'All-phase LQ': f"{full_lq:.2f}",
            'Phase 3 LQ': f"{p3_lq:.2f}",
            'Δ': f"{lq_delta:+.2f}",
            'P3 n': f"{p3_n:,}",
        })
    else:
        phase_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'All-phase LQ': f"{full_lq:.2f}",
            'Phase 3 LQ': '—',
            'Δ': 'insufficient',
            'P3 n': '—',
        })

if phase_validation:
    df_phase_val = pd.DataFrame(phase_validation)
    display(df_phase_val.style.hide(axis='index'))
    display(Markdown("*|Δ| > 0.3 suggests a phase-specific pattern.*"))

# ============================================================
# 3.2.3 Temporal LQ validation: 2015+ vs full period
# ============================================================

display(Markdown("### 3.2.3 Temporal LQ validation (2015+ vs full period)"))
display(Markdown("*Checks if historical patterns persist in recent data.*"))

query_recent = """
WITH 
recent_study_countries AS (
    SELECT DISTINCT l.study_id, l.country
    FROM locations l
    JOIN v_studies_clean s ON l.study_id = s.study_id
    WHERE l.country IS NOT NULL AND l.country != ''
      AND s.start_year >= 2015
      AND s.is_start_year_in_scope = 1
),
recent_study_conditions AS (
    SELECT DISTINCT c.study_id, LOWER(TRIM(c.condition_name)) AS condition_standardized
    FROM conditions c
    JOIN v_studies_clean s ON c.study_id = s.study_id
    WHERE c.condition_name IS NOT NULL AND TRIM(c.condition_name) != ''
      AND s.start_year >= 2015
      AND s.is_start_year_in_scope = 1
),
recent_country_condition AS (
    SELECT sc.country, scond.condition_standardized, COUNT(DISTINCT sc.study_id) AS n_trials
    FROM recent_study_countries sc
    JOIN recent_study_conditions scond ON sc.study_id = scond.study_id
    GROUP BY sc.country, scond.condition_standardized
),
recent_country_totals AS (
    SELECT country, COUNT(DISTINCT study_id) AS country_total
    FROM recent_study_countries GROUP BY country
),
recent_condition_totals AS (
    SELECT condition_standardized, COUNT(DISTINCT study_id) AS condition_total
    FROM recent_study_conditions GROUP BY condition_standardized
),
recent_global AS (
    SELECT COUNT(DISTINCT sc.study_id) AS global_total
    FROM recent_study_countries sc
    JOIN recent_study_conditions scond ON sc.study_id = scond.study_id
)
SELECT 
    rcc.country, rcc.condition_standardized, rcc.n_trials,
    rct.country_total, rcond.condition_total, rg.global_total,
    ROUND((CAST(rcc.n_trials AS REAL) / rct.country_total) / 
          (CAST(rcond.condition_total AS REAL) / rg.global_total), 3) AS lq_recent
FROM recent_country_condition rcc
JOIN recent_country_totals rct ON rcc.country = rct.country
JOIN recent_condition_totals rcond ON rcc.condition_standardized = rcond.condition_standardized
CROSS JOIN recent_global rg
WHERE rct.country_total >= 30 AND rcond.condition_total >= 50 AND rcc.n_trials >= 3
"""

df_recent_spec = pd.read_sql(query_recent, conn)

temporal_validation = []
for country in major_countries[:5]:
    country_top_full = df_high_lq[df_high_lq['country'] == country].head(1)
    if len(country_top_full) == 0:
        continue
    
    top_cond = country_top_full.iloc[0]['condition_standardized']
    full_lq = country_top_full.iloc[0]['location_quotient']
    
    recent_match = df_recent_spec[
        (df_recent_spec['country'] == country) & 
        (df_recent_spec['condition_standardized'] == top_cond)
    ]
    
    if len(recent_match) > 0:
        recent_lq = recent_match.iloc[0]['lq_recent']
        recent_n = recent_match.iloc[0]['n_trials']
        lq_delta = recent_lq - full_lq
        lq_delta_pct = (lq_delta / full_lq * 100) if full_lq > 0 else 0
        
        temporal_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'Full LQ': f"{full_lq:.2f}",
            '2015+ LQ': f"{recent_lq:.2f}",
            'ΔLQ': f"{lq_delta:+.2f} ({lq_delta_pct:+.0f}%)",
        })
    else:
        temporal_validation.append({
            'Country': country,
            'Condition': top_cond[:22] + '...' if len(top_cond) > 22 else top_cond,
            'Full LQ': f"{full_lq:.2f}",
            '2015+ LQ': '—',
            'ΔLQ': 'insufficient',
        })

if temporal_validation:
    df_temporal = pd.DataFrame(temporal_validation)
    display(df_temporal.style.hide(axis='index'))
    display(Markdown("*|ΔLQ| > 20% suggests historical patterns may not reflect current activity.*"))

# Threshold sensitivity
thresholds = [1.5, 2.0, 2.5]
sensitivity = [len(df_spec[df_spec['location_quotient'] > t]) for t in thresholds]
display(Markdown(f"**Threshold sensitivity:** LQ>1.5: {sensitivity[0]:,} | LQ>2.0: {sensitivity[1]:,} | LQ>2.5: {sensitivity[2]:,} pairs"))

In [None]:
# ============================================================
# 3.2.3 LQ Heatmap: Top countries × Top conditions
# ============================================================

# Prepare data for heatmap: top 8 countries × top 12 conditions
top8_countries = df_country.head(8)['primary_country'].tolist()
top12_conditions = (
    df_spec
    .groupby('condition_standardized')['n_trials'].sum()
    .nlargest(12)
    .index.tolist()
)

# Filter and pivot for LQ values
df_heatmap_raw = df_spec[
    (df_spec['country'].isin(top8_countries)) &
    (df_spec['condition_standardized'].isin(top12_conditions))
].pivot_table(
    index='country',
    columns='condition_standardized',
    values='location_quotient',
    fill_value=1.0
)

# Also pivot for n_trials (sample size)
df_n_raw = df_spec[
    (df_spec['country'].isin(top8_countries)) &
    (df_spec['condition_standardized'].isin(top12_conditions))
].pivot_table(
    index='country',
    columns='condition_standardized',
    values='n_trials',
    fill_value=0
)

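# Countries absent from df_spec become all-NaN rows after reindexing; fill them like
# other missing cells (LQ = 1.0, n = 0) so min/max and annotations stay numeric.
df_heatmap_raw = df_heatmap_raw.reindex(top8_countries).fillna(1.0)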
df_n_raw = df_n_raw.reindex(top8_countries).fillna(0)

short_names = {c: c[:20] + '...' if len(c) > 20 else c for c in df_heatmap_raw.columns}
df_heatmap = df_heatmap_raw.rename(columns=short_names)

# Create text annotations with n indicator for low-sample cells
# † = n < 50 (interpret with caution)
LOW_N_THRESHOLD = 50
text_annotations = []
low_n_cells = []
for i in range(df_heatmap_raw.shape[0]):
    row = []
    for j in range(df_heatmap_raw.shape[1]):
        raw_val = df_heatmap_raw.values[i, j]
        n_val = df_n_raw.values[i, j] if i < df_n_raw.shape[0] and j < df_n_raw.shape[1] else 0
        
        # Format LQ with n indicator
        if n_val < LOW_N_THRESHOLD:
            suffix = "†"
            low_n_cells.append((df_heatmap_raw.index[i], df_heatmap_raw.columns[j][:15], int(n_val), raw_val))
        else:
            suffix = ""
        
        if raw_val > 2.0:
            row.append(f"{raw_val:.1f}*{suffix}")
        elif raw_val < 0.5:
            row.append(f"{raw_val:.2f}*{suffix}")
        else:
            row.append(f"{raw_val:.1f}{suffix}")
    text_annotations.append(row)

min_lq = df_heatmap_raw.values.min()
max_lq = df_heatmap_raw.values.max()

# Create heatmap
fig_heatmap = go.Figure(data=go.Heatmap(
    z=np.clip(df_heatmap.values, 0.5, 2.0),
    x=list(df_heatmap.columns),
    y=list(df_heatmap.index),
    colorscale=[
        [0, '#2166ac'],
        [0.4, '#f7f7f7'],
        [0.6, '#f7f7f7'],
        [1, '#b2182b']
    ],
    zmin=0.5,
    zmax=2.0,
    text=text_annotations,
    texttemplate='%{text}',
    textfont=dict(size=10),
    hovertemplate='<b>%{y}</b> × %{x}<br>LQ (clamped to 0.5–2.0) = %{z:.2f}<extra></extra>',
    colorbar=dict(
        title='LQ',
        tickvals=[0.5, 1.0, 1.5, 2.0],
        ticktext=['≤0.5', '1.0', '1.5', '≥2.0'],
    ),
))

fig_heatmap.update_layout(
    title=dict(
        text='<b>Location Quotient: Country × Condition</b><br>'
             '<span style="font-size:11px;color:gray">Red = over-represented | White = average | Blue = under-represented | '
             f'†=n<{LOW_N_THRESHOLD}</span>',
        x=0.5,
        xanchor='center',
    ),
    xaxis=dict(title=None, tickangle=45, tickfont=dict(size=10)),
    yaxis=dict(title=None, tickfont=dict(size=11)),
    height=450,
    width=900,
    template='plotly_white',
    margin=dict(l=120, r=80, t=80, b=120),
)

fig_heatmap.show()

# Notes
display(Markdown(f"""
*Scale clamped to 0.5–2.0. Actual range: [{min_lq:.2f}, {max_lq:.1f}].*

**† = n < {LOW_N_THRESHOLD}:** These cells have small sample sizes; LQ estimates are unstable.
"""))

if low_n_cells:
    n_low = len(low_n_cells)
    display(Markdown(f"*{n_low} of {df_heatmap_raw.size} cells have n < {LOW_N_THRESHOLD}.*"))


In [None]:
# ============================================================
# 3.3 LQ methodology note
# ============================================================

display(Markdown(f"""
### 3.3 Why LQ instead of χ²?

Trials map to ~{mean_conds_per_trial:.1f} conditions on average, which violates the independence assumption of the χ² test. LQ reports magnitude (over-/under-representation) without the inflated p-values that test would produce.
"""))


### 3.3 LQ vs χ²

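χ² requires independent observations; here each trial maps to ~1.8 conditions on average, so one trial contributes several country–condition counts and the effective sample size is inflated. LQ measures magnitude instead: LQ = 1.5 means a condition's share in the country is 50% above its global share, and LQ = 2.0 means twice the global share.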
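
The interpretation above follows from the standard LQ definition: a condition's share of a country's registrations divided by its share globally. A minimal sketch with made-up counts (the function name `location_quotient` and all numbers are illustrative assumptions, not registry values):

```python
def location_quotient(n_country_condition, n_country_total,
                      n_condition_total, n_global_total):
    """LQ = (condition's share within the country) / (condition's share globally)."""
    if min(n_country_total, n_condition_total, n_global_total) <= 0:
        raise ValueError("totals must be positive")
    country_share = n_country_condition / n_country_total
    global_share = n_condition_total / n_global_total
    return country_share / global_share


# Hypothetical counts: 60 of a country's 400 trial-condition rows fall in a
# condition that accounts for 1,000 of 10,000 rows globally (0.15 vs 0.10).
lq = location_quotient(60, 400, 1_000, 10_000)
print(f"LQ = {lq:.2f}")
```

Because the totals are sums of trial-condition rows, a trial listing several conditions enters the denominators more than once, which is the inflation the caveats describe.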


---

## 4. Temporal Trends in Geographic Distribution

**Question:** Has the geographic distribution of trials shifted over time?

In [None]:
# ============================================================
# 4.1 Temporal trends with distributional shift metrics
# ============================================================

from scipy.spatial.distance import jensenshannon

df_geo['start_cohort'] = pd.cut(df_geo['start_year'], bins=COHORT_BINS, labels=COHORT_LABELS)

display(Markdown(f"""
### 4.1 Temporal analysis

**Cohorts:** {', '.join(COHORT_LABELS)}
"""))

# ============================================================
# 4.1.1 Jensen-Shannon Divergence between cohorts
# ============================================================

all_countries = df_geo['primary_country'].unique()
n_categories = len(all_countries)

cohort_distributions = {}
for cohort in COHORT_LABELS:
    cohort_data = df_geo[df_geo['start_cohort'] == cohort]
    if len(cohort_data) > 0:
        counts = cohort_data['primary_country'].value_counts()
        dist = pd.Series(0.0, index=all_countries)
        dist.update(counts / counts.sum())
        cohort_distributions[cohort] = dist.values

jsd_results = []
cohort_list = list(cohort_distributions.keys())
for i in range(len(cohort_list) - 1):
    c1, c2 = cohort_list[i], cohort_list[i + 1]
    jsd = jensenshannon(cohort_distributions[c1], cohort_distributions[c2])
    jsd_results.append({'Comparison': f'{c1} → {c2}', 'JSD': jsd})

if len(cohort_list) >= 2:
    jsd_first_last = jensenshannon(cohort_distributions[cohort_list[0]], cohort_distributions[cohort_list[-1]])
    jsd_results.append({'Comparison': f'{cohort_list[0]} → {cohort_list[-1]}', 'JSD': jsd_first_last})

df_jsd = pd.DataFrame(jsd_results)
max_jsd = df_jsd['JSD'].max()

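# JSD interpretation (heuristic thresholds), applied to the first → last comparison
# so the label matches the value it is reported with
if jsd_first_last < 0.05:
    jsd_interp = "negligible shift"
elif jsd_first_last < 0.10:
    jsd_interp = "minor shift"
elif jsd_first_last < 0.20:
    jsd_interp = "moderate shift"
else:
    jsd_interp = "substantial shift"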

# Bootstrap CI for JSD
np.random.seed(44)
n_jsd_bootstrap = 500
first_cohort_data = df_geo[df_geo['start_cohort'] == cohort_list[0]]['primary_country'].values
last_cohort_data = df_geo[df_geo['start_cohort'] == cohort_list[-1]]['primary_country'].values

jsd_bootstrap = []
for _ in range(n_jsd_bootstrap):
    boot_first = np.random.choice(first_cohort_data, size=len(first_cohort_data), replace=True)
    boot_last = np.random.choice(last_cohort_data, size=len(last_cohort_data), replace=True)
    all_countries_boot = np.union1d(np.unique(boot_first), np.unique(boot_last))
    dist_first = pd.Series(boot_first).value_counts().reindex(all_countries_boot, fill_value=0) / len(boot_first)
    dist_last = pd.Series(boot_last).value_counts().reindex(all_countries_boot, fill_value=0) / len(boot_last)
    jsd_boot = jensenshannon(dist_first.values, dist_last.values)
    jsd_bootstrap.append(jsd_boot)

jsd_ci_low = np.percentile(jsd_bootstrap, 2.5)
jsd_ci_high = np.percentile(jsd_bootstrap, 97.5)

display(Markdown(f"""
### 4.1.1 Distributional shift (Jensen-Shannon distance)

Values come from SciPy's `jensenshannon`, which returns the Jensen-Shannon *distance* (the square root of the divergence, natural log). It ranges from 0 (identical) to √ln 2 ≈ 0.83 (no overlap). Thresholds: <0.05 negligible, 0.05–0.10 minor, 0.10–0.20 moderate, >0.20 substantial.

*Note: Thresholds are heuristic for ~{n_categories} countries. A permutation null would give more rigorous calibration.*
"""))
display(df_jsd.style.format({'JSD': '{:.3f}'}).hide(axis='index'))
display(Markdown(f"**{cohort_list[0]} → {cohort_list[-1]}:** JSD = {jsd_first_last:.3f} [95% CI: {jsd_ci_low:.3f}–{jsd_ci_high:.3f}] → {jsd_interp}"))

# ============================================================
# 4.1.2 Global share by cohort
# ============================================================

top10_countries = df_country.head(10)['primary_country'].tolist()
cohort_global_totals = df_geo.groupby('start_cohort', observed=True).size().reset_index(name='global_total')

temporal_global = (
    df_geo[df_geo['primary_country'].isin(top10_countries)]
    .groupby(['start_cohort', 'primary_country'], observed=True)
    .size()
    .reset_index(name='n_trials')
)
temporal_global = temporal_global.merge(cohort_global_totals, on='start_cohort')
temporal_global['global_share'] = temporal_global['n_trials'] / temporal_global['global_total'] * 100

temporal_global_pivot = temporal_global.pivot_table(
    index='primary_country', columns='start_cohort', values='global_share', fill_value=0, observed=True
).round(1).reindex(top10_countries)

display(Markdown("### 4.1.2 Global share by country and cohort (%)"))
display(temporal_global_pivot.style.format("{:.1f}%"))

# Top-10 combined share
top10_combined = temporal_global.groupby('start_cohort', observed=True)['n_trials'].sum().reset_index()
top10_combined = top10_combined.merge(cohort_global_totals, on='start_cohort')
top10_combined['top10_share'] = top10_combined['n_trials'] / top10_combined['global_total'] * 100

if len(top10_combined) >= 2:
    delta_top10 = top10_combined.iloc[-1]['top10_share'] - top10_combined.iloc[0]['top10_share']
    direction = "diversifying" if delta_top10 < -3 else "concentrating" if delta_top10 > 3 else "stable"
    display(Markdown(f"**Top-10 combined:** {top10_combined.iloc[0]['top10_share']:.0f}% → {top10_combined.iloc[-1]['top10_share']:.0f}% ({delta_top10:+.0f}pp, {direction})"))

# ============================================================
# 4.1.3 Formal trend test (Spearman correlation)
# ============================================================

trend_tests = []
for country in ['United States', 'China', 'India']:
    if country in temporal_global_pivot.index:
        shares = temporal_global_pivot.loc[country].values
        cohort_idx = np.arange(len(shares))
        if len(shares) >= 3:
            rho, p_val = spearmanr(cohort_idx, shares)
            trend_dir = "increasing" if rho > 0.3 else "decreasing" if rho < -0.3 else "no clear trend"
            sig = "sig." if p_val < 0.05 else "n.s."
            trend_tests.append({
                'Country': country,
                'ρ': f"{rho:.2f}",
                'p': f"{p_val:.3f}" if p_val >= 0.001 else "<.001",
                'Direction': f"{trend_dir} ({sig})"
            })

if trend_tests:
    display(Markdown("**Trend tests (Spearman ρ of share vs cohort index; p-values use SciPy's t approximation and are unreliable with this few cohorts):**"))
    display(pd.DataFrame(trend_tests).style.hide(axis='index'))

# ============================================================
# 4.1.4 Key country shifts
# ============================================================

top5_countries = df_country.head(5)['primary_country'].tolist()
earliest_cohort, latest_cohort = '2000-2009', '2020-2025'

if earliest_cohort in temporal_global_pivot.columns and latest_cohort in temporal_global_pivot.columns:
    shifts = []
    for country in top5_countries:
        early = temporal_global_pivot.loc[country, earliest_cohort]
        late = temporal_global_pivot.loc[country, latest_cohort]
        delta = late - early
        direction_arrow = '↑' if delta > 5 else '↓' if delta < -5 else '→'
        shifts.append(f"{country}: {early:.0f}% → {late:.0f}% ({delta:+.0f}pp) {direction_arrow}")
    
    display(Markdown(f"**Key shifts ({earliest_cohort} → {latest_cohort}):** " + " | ".join(shifts)))

display(Markdown("""
**Confounders:**
- FDAAA 2007 mandated US registration → early US shares inflated
- Non-US registries growing → ClinicalTrials.gov share declining
- Phase/condition mix varies by period
"""))

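The threshold note in §4.1.1 mentions a permutation null as a more rigorous calibration. One way this could look is sketched below in plain Python, under stated assumptions: the helper names (`js_distance`, `shares`, `permutation_p_value`) and the toy cohort labels are hypothetical, and the distance is written out by hand to mirror what SciPy's `jensenshannon` computes by default (square root of the divergence, natural log).

```python
import math
import random
from collections import Counter


def js_distance(p, q):
    """Jensen-Shannon distance (natural log): square root of the JS divergence."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(x, y):
        # Terms with x_i = 0 contribute nothing; y_i > 0 whenever x_i > 0.
        return sum(a * math.log(a / b) for a, b in zip(x, y) if a > 0)

    divergence = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    return math.sqrt(max(divergence, 0.0))


def shares(labels, categories):
    counts = Counter(labels)
    return [counts[c] / len(labels) for c in categories]


def permutation_p_value(first, last, n_perm=1000, seed=0):
    """Fraction of label-shuffled splits whose distance >= the observed one."""
    rng = random.Random(seed)
    categories = sorted(set(first) | set(last))
    observed = js_distance(shares(first, categories), shares(last, categories))
    pooled = list(first) + list(last)
    hits = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        a, b = pooled[:len(first)], pooled[len(first):]
        if js_distance(shares(a, categories), shares(b, categories)) >= observed:
            hits += 1
    return observed, (hits + 1) / (n_perm + 1)


# Non-overlapping distributions reach the upper bound sqrt(ln 2), about 0.83.
print(round(js_distance([1.0, 0.0], [0.0, 1.0]), 4))

# Toy cohorts (made-up labels, not registry data).
early = ["US"] * 70 + ["DE"] * 20 + ["CN"] * 10
late = ["US"] * 45 + ["DE"] * 20 + ["CN"] * 35
obs, p_perm = permutation_p_value(early, late)
print(f"observed distance = {obs:.3f}, permutation p = {p_perm:.3f}")
```

Shuffling cohort membership while keeping cohort sizes fixed gives the distances expected if both cohorts shared one country mix. This answers a different question from the bootstrap CI, which describes sampling variability around the observed value.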

---

## 5. Summary & Implications

In [None]:
# ============================================================
# 5.1 Summary (Descriptive)
# ============================================================

top_country_name = df_country.iloc[0]['primary_country']
top_country_share = df_country.iloc[0]['market_share'] * 100

n_high_lq = len(df_high_lq)

# Fallback for variables that may not exist
direction = direction if 'direction' in dir() else "unknown"
trend_tests_exist = 'trend_tests' in dir() and len(trend_tests) > 0

display(Markdown(f"""
## Summary

**This section summarizes descriptive patterns. The methods do not support causal or inferential claims.**

### Concentration (§2)

Site-weighted HHI = {hhi_site_weighted:.0f}, which falls in the "{hhi_site_interp}" range using DOJ/FTC thresholds.

- Primary-country HHI ({hhi:.0f}) is {hhi_delta_method:.0f} points higher due to multinational trial assignment.
- Site-weighted HHI counts participation, not capacity—a country with 100 sites contributes 100× a country with 1 site, regardless of enrollment volume.
- Neither method controls for phase or sponsor composition.

### Specialization (§3)

{n_high_lq:,} country–condition pairs have LQ > 1.5.

**These are not validated specializations.** The LQ values carry three limitations:
- Condition labels are unnormalized (synonyms are counted separately)
- Multi-condition trials count once per condition (inflating denominators for co-occurring conditions)
- No confidence intervals are computed (point estimates are unstable for small n)

Sensitivity analyses (§3.2.1–3.2.3) show instability in several high-LQ pairs when restricted to single-condition trials, Phase 3, or 2015+.

### Temporal (§4)

JSD = {jsd_first_last:.3f} between earliest and latest cohorts ({jsd_interp}).

- JSD thresholds are heuristic; no formal null distribution was computed.
- Spearman trend tests have low power with only {len(cohort_list)} cohorts.
- Confounders: FDAAA 2007 (US mandate), growth of non-US registries.

### For site selection

LQ indicates registration patterns, not current capacity or performance. Before using this analysis:

1. Validate apparent specializations against Phase 3 and 2015+ data
2. Recognize that condition labels are fragmented
3. Supplement with enrollment rate and site experience data
"""))

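On the low power of the Spearman tests: with only a handful of cohorts, the exact permutation distribution of ρ is very coarse. For example, with four cohorts a perfectly monotone series has an exact two-sided p of 2/24 ≈ 0.083, so it cannot fall below 0.05. The sketch below enumerates all orderings to show this. The helper names and the share values are hypothetical, and it assumes the series is not constant.

```python
from itertools import permutations


def ranks(values):
    """Average ranks (1-based), so ties share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg
        i = j + 1
    return result


def pearson(x, y):
    mx, my = sum(x) / len(x), sum(y) / len(y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    return sxy / (sxx * syy) ** 0.5


def exact_spearman(x, y):
    """Spearman rho with a two-sided p-value from all n! orderings of y."""
    rx, ry = ranks(x), ranks(y)
    rho = pearson(rx, ry)
    perms = list(permutations(ry))
    # Small tolerance so orderings that tie with the observed |rho| are counted.
    extreme = sum(abs(pearson(rx, list(p))) >= abs(rho) - 1e-12 for p in perms)
    return rho, extreme / len(perms)


# Hypothetical shares for one country across 4 cohorts (strictly decreasing).
rho, p_exact = exact_spearman([0, 1, 2, 3], [62.0, 55.0, 47.0, 41.0])
print(f"rho = {rho:.2f}, exact p = {p_exact:.3f}")
```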
---

## Appendix A: Site Complexity by Phase/Sponsor

*Supplementary analysis for feasibility planning; it does not directly address the geographic distribution question.*

In [None]:
# ============================================================
# A.1 Site complexity by phase
# ============================================================

df_phase_sites = df_geo[
    (df_geo['is_interventional'] == 1) &
    (df_geo['phase_group'].notna()) &
    (df_geo['phase_group'] != 'Not Applicable') &
    (df_geo['phase_group'] != 'Other')
].copy()

df_phase_sites['phase_group'] = pd.Categorical(df_phase_sites['phase_group'], categories=PHASE_ORDER_CLINICAL, ordered=True)

phase_summary = (
    df_phase_sites
    .groupby('phase_group', observed=True)
    .agg(
        n_trials=('study_id', 'nunique'),
        median_sites=('n_sites', 'median'),
        mean_sites=('n_sites', 'mean'),
        q75_sites=('n_sites', lambda x: x.quantile(0.75)),
        pct_multinational=('is_multinational', 'mean'),
    )
    .reset_index()
)
phase_summary['pct_multinational'] = phase_summary['pct_multinational'] * 100

display(Markdown("### A.1 Site complexity by phase (interventional)"))
display(
    phase_summary
    .rename(columns={
        'phase_group': 'Phase', 'n_trials': 'N', 'median_sites': 'Median',
        'mean_sites': 'Mean', 'q75_sites': 'Q75', 'pct_multinational': 'Multinational %',
    })
    .style.format({'N': '{:,.0f}', 'Median': '{:.0f}', 'Mean': '{:.1f}', 'Q75': '{:.0f}', 'Multinational %': '{:.1f}%'})
    .hide(axis='index')
)

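# Key values for reference (medians kept as floats; None when a phase is missing)
p1_sites = phase_summary.loc[phase_summary['phase_group'] == 'Phase 1', 'median_sites'].values
p3_sites = phase_summary.loc[phase_summary['phase_group'] == 'Phase 3', 'median_sites'].values
p1_val = float(p1_sites[0]) if len(p1_sites) > 0 else None
p3_val = float(p3_sites[0]) if len(p3_sites) > 0 else None

# Format the ratio only when both medians exist and the Phase 1 median is non-zero
if p1_val is not None and p3_val is not None and p1_val > 0:
    site_ratio = p3_val / p1_val
    display(Markdown(f"Phase 3 median = {p3_val:.0f} sites vs Phase 1 = {p1_val:.0f} ({site_ratio:.1f}× ratio)"))
else:
    site_ratio = None
    display(Markdown("Phase 1 or Phase 3 median site count unavailable; ratio not computed."))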

In [None]:
# ============================================================
# A.2 Sponsor effect on site count (Phase 3)
# ============================================================

p3_data = df_phase_sites[df_phase_sites['phase_group'] == 'Phase 3'].copy()
p3_industry = p3_data[p3_data['is_industry_sponsor'] == 1]['n_sites']
p3_non_industry = p3_data[p3_data['is_industry_sponsor'] == 0]['n_sites']

u_stat, p_mw = mannwhitneyu(p3_industry, p3_non_industry, alternative='two-sided')
n1, n2 = len(p3_industry), len(p3_non_industry)
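# Rank-biserial effect size. The sign depends on which U SciPy returns: if u_stat is the U of the
# first sample (industry), a negative r here means industry trials tend to have MORE sites.
r_biserial = 1 - (2 * u_stat) / (n1 * n2)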
effect_label = interpret_effect_size(r_biserial, metric="r")
median_ratio = p3_industry.median() / p3_non_industry.median() if p3_non_industry.median() > 0 else float('inf')

display(Markdown(f"""
### A.2 Sponsor effect (Phase 3) — Descriptive Only

| Sponsor | N | Median Sites |
|---------|---|--------------|
| Industry | {n1:,} | {p3_industry.median():.0f} |
| Non-industry | {n2:,} | {p3_non_industry.median():.0f} |

Mann-Whitney r = {r_biserial:.2f} ({effect_label}), p {'< 0.001' if p_mw < 0.001 else f'= {p_mw:.3f}'}. The median site count of industry Phase 3 trials is {median_ratio:.1f}× that of non-industry trials.
"""))

# ============================================================
# Stratified analysis: control for therapeutic area (oncology vs non-oncology)
# ============================================================

# Identify oncology trials (heuristic: condition contains "cancer", "tumor", "carcinoma", etc.)
oncology_keywords = ['cancer', 'tumor', 'tumour', 'carcinoma', 'leukemia', 'lymphoma', 'melanoma', 'sarcoma', 'oncolog']

# Merge with conditions to identify oncology trials
df_conds = pd.read_sql("""
    SELECT c.study_id, LOWER(c.condition_name) AS condition_lower
    FROM conditions c
    JOIN v_studies_clean s ON c.study_id = s.study_id
    WHERE s.is_start_year_in_scope = 1
""", conn)

oncology_studies = df_conds[df_conds['condition_lower'].str.contains('|'.join(oncology_keywords), na=False)]['study_id'].unique()
p3_data['is_oncology'] = p3_data['study_id'].isin(oncology_studies).astype(int)

# Stratified comparison
strat_results = []
for is_onc, label in [(1, 'Oncology'), (0, 'Non-oncology')]:
    stratum = p3_data[p3_data['is_oncology'] == is_onc]
    ind = stratum[stratum['is_industry_sponsor'] == 1]['n_sites']
    non_ind = stratum[stratum['is_industry_sponsor'] == 0]['n_sites']
    
    if len(ind) >= 10 and len(non_ind) >= 10:
        u, p = mannwhitneyu(ind, non_ind, alternative='two-sided')
        r = 1 - (2 * u) / (len(ind) * len(non_ind))
        ratio = ind.median() / non_ind.median() if non_ind.median() > 0 else float('inf')
        
        strat_results.append({
            'Stratum': label,
            'Industry (n, median)': f"{len(ind):,}, {ind.median():.0f}",
            'Non-industry (n, median)': f"{len(non_ind):,}, {non_ind.median():.0f}",
            'Ratio': f"{ratio:.1f}×",
            'r': f"{r:.2f}",
        })
    else:
        strat_results.append({
            'Stratum': label,
            'Industry (n, median)': f"{len(ind):,}, —" if len(ind) < 10 else f"{len(ind):,}, {ind.median():.0f}",
            'Non-industry (n, median)': f"{len(non_ind):,}, —" if len(non_ind) < 10 else f"{len(non_ind):,}, {non_ind.median():.0f}",
            'Ratio': 'n/a',
            'r': 'n/a',
        })

df_strat = pd.DataFrame(strat_results)
display(Markdown("**Stratified by therapeutic area:**"))
display(df_strat.style.hide(axis='index'))

# Check if sponsor effect attenuates within strata
onc_r = float(strat_results[0]['r']) if strat_results[0]['r'] != 'n/a' else None
non_onc_r = float(strat_results[1]['r']) if strat_results[1]['r'] != 'n/a' else None
overall_r = r_biserial

if onc_r is not None and non_onc_r is not None:
    avg_strat_r = (abs(onc_r) + abs(non_onc_r)) / 2
    confounding_note = "attenuates" if avg_strat_r < abs(overall_r) * 0.8 else "persists"
    display(Markdown(f"""
*Effect {confounding_note} after stratification (overall r = {overall_r:.2f}, oncology r = {onc_r:.2f}, non-oncology r = {non_onc_r:.2f}). This is consistent with {'partial confounding by therapeutic area' if confounding_note == 'attenuates' else 'a sponsor effect that holds within both therapeutic-area strata'}.*
"""))
else:
    display(Markdown("*Insufficient data for stratified comparison.*"))

display(Markdown("""
**Interpretation:** This appendix provides descriptive context for feasibility planning. The sponsor–site relationship reflects multiple factors (indication, enrollment targets, regulatory requirements) that are not disentangled here.
"""))

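The sign of the rank-biserial r reported in A.2 depends on which sample's U statistic is used. As a minimal sketch (the helper names and site counts are made up for illustration), the common convention r = 2·U_x/(n_x·n_y) − 1 is positive when the first sample tends to exceed the second. Assuming the U passed in is the statistic for the first sample, which is how recent SciPy versions document `mannwhitneyu`, the expression `1 - 2U/(n1*n2)` used in the cell above yields the same magnitude with the opposite sign.

```python
def mann_whitney_u(x, y):
    """U for sample x: count of (x_i, y_j) pairs with x_i > y_j; ties count 0.5."""
    return sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in x for b in y)


def rank_biserial(x, y):
    """r = 2*U_x/(n_x*n_y) - 1; positive when x tends to exceed y."""
    u_x = mann_whitney_u(x, y)
    return 2 * u_x / (len(x) * len(y)) - 1


# Made-up site counts, for illustration only.
industry = [12, 30, 45, 8, 60]
non_industry = [3, 5, 12, 2, 9]

r = rank_biserial(industry, non_industry)
r_swapped = rank_biserial(non_industry, industry)
print(f"r (industry vs non-industry) = {r:+.2f}")
print(f"r (non-industry vs industry) = {r_swapped:+.2f}")
```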
---

## Cleanup

In [None]:
# Close database connection
conn.close()
print("Done.")