In [4]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from statsmodels.tsa.holtwinters import ExponentialSmoothing
import plotly.offline as py

In [5]:

# WHO member states grouped by WHO region code.
# Keys are region codes; values are the country names that are matched against
# the NAME column of the vaccination data when regions are assigned.
# Every country must appear in exactly one region: a name listed twice would
# duplicate that country's rows when the region table is joined onto the data.
# Israel belongs to EUR only; the EMR list previously repeated "Israel" where
# "Lebanon" (an EMR member that was otherwise missing) belongs.
# AFR lists both "Cape Verde" and "Cabo Verde" as alternative spellings.
who_countries = {
    "AFR": ["Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cameroon", "Cape Verde", "Cabo Verde", "Central African Republic", "Chad", "Comoros", "Ivory Coast", "Democratic Republic of the Congo", "Equatorial Guinea", "Eritrea", "Ethiopia", "Gabon", "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Kenya", "Lesotho", "Liberia", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Mozambique", "Namibia", "Niger", "Nigeria", "Republic of the Congo", "Rwanda", "São Tomé and Príncipe", "Senegal", "Seychelles", "Sierra Leone", "South Africa", "South Sudan", "Eswatini", "Togo", "Uganda", "Tanzania", "Zambia", "Zimbabwe"],
    "AMR": ["Peru", "Paraguay", "Saint Kitts and Nevis", "Antigua and Barbuda", "Argentina", "Bahamas", "Barbados", "Belize", "Bolivia", "Brazil", "Canada", "Chile", "Colombia", "Costa Rica", "Cuba", "Dominica", "Dominican Republic", "Ecuador", "El Salvador", "Grenada", "Guatemala", "Guyana", "Haiti", "Honduras", "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Lucia", "Saint Vincent and the Grenadines", "Suriname", "Trinidad and Tobago", "the United States of America", "Uruguay", "Venezuela"],
    "SEAR": ["Bangladesh", "Bhutan", "Democratic People's Republic of Korea", "India", "Maldives", "Myanmar", "Nepal", "Sri Lanka", "Thailand", "Timor-Leste"],
    "EUR": ["Albania", "Andorra", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina", "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France", "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Israel", "Italy", "Kazakhstan", "Kyrgyzstan", "Latvia", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco", "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania", "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland", "Tajikistan", "Turkey", "Turkmenistan", "Ukraine", "United Kingdom", "Uzbekistan"],
    "EMR": ["Libya", "Afghanistan", "Bahrain", "Djibouti", "Egypt", "Iran", "Iraq", "Jordan", "Kuwait", "Lebanon", "Oman", "Pakistan", "Qatar", "Saudi Arabia", "Somalia", "Sudan", "Syria", "Tunisia", "United Arab Emirates", "Yemen", "Morocco"],
    "WPR": ["Australia", "Brunei", "Cambodia", "China", "Cook Islands", "Fiji", "Indonesia", "Japan", "Kiribati", "Laos", "Malaysia", "Marshall Islands", "Micronesia", "Mongolia", "Nauru", "New Zealand", "Niue", "Palau", "Papua New Guinea", "Philippines", "Samoa", "Singapore", "Solomon Islands", "South Korea", "Taiwan", "Tonga", "Tuvalu", "Vanuatu", "Vietnam"]
}

# Maps alternative country spellings (keys) to the short names used in
# who_countries (values). It is applied to the NAME column before regions are
# joined, so any spelling missing here leaves that country without a REGION
# and silently excludes it from the regional aggregates.
name_recode = {
    "Bolivia (Plurinational State of)": "Bolivia",
    "Brunei Darussalam": "Brunei",
    "Côte d'Ivoire": "Ivory Coast",
    "Czechia": "Czech Republic",
    "Lao People's Democratic Republic": "Laos",
    "Republic of Moldova": "Moldova",
    "Russian Federation": "Russia",
    "Sao Tome and Principe": "São Tomé and Príncipe",
    "Türkiye": "Turkey",
    "United States of America": "the United States of America",
    "Micronesia (Federated States of)": "Micronesia",
    "Netherlands (Kingdom of the)": "Netherlands",
    "Republic of Korea": "South Korea",
    "United Kingdom of Great Britain and Northern Ireland": "United Kingdom",
    "United Republic of Tanzania": "Tanzania",
    # Remaining WHO official spellings whose short forms appear in
    # who_countries. NOTE(review): assumes the data file uses these official
    # names like the entries above; the extra keys are harmless if it does not.
    "Congo": "Republic of the Congo",
    "Iran (Islamic Republic of)": "Iran",
    "Syrian Arab Republic": "Syria",
    "Venezuela (Bolivarian Republic of)": "Venezuela",
    "Viet Nam": "Vietnam",
}

# Read the vaccination data (assuming full dataset for time series)
vax = pd.read_csv("who_vax_country.tsv", sep="\t", header=0)

# Keep only the HPV first-dose (females) programme-coverage rows.
# .copy() makes the subset an independent frame, so the column assignments
# below modify it directly instead of a slice of the full table (which
# triggers SettingWithCopyWarning and is unreliable under chained assignment).
vax = vax[vax['ANTIGEN_DESCRIPTION'] == "HPV Vaccination program coverage, first dose, females"].copy()

# Convert COVERAGE to nullable integer so missing values stay missing
vax['COVERAGE'] = vax['COVERAGE'].astype('Int64')

# Harmonise country names with the spellings used in who_countries
vax['NAME'] = vax['NAME'].replace(name_recode)

# Build a long-format lookup table with one row per (country name, region).
# drop_duplicates keeps the first region listed for a name, so a country that
# appears under two regions cannot multiply rows in the join below.
who_region_df = pd.DataFrame(
    [
        (country, region)
        for region, countries in who_countries.items()
        for country in countries
    ],
    columns=['NAME', 'REGION'],
).drop_duplicates(subset='NAME', keep='first')

# Attach the region to every vaccination row. Rows whose NAME has no match
# keep REGION = NaN. validate="m:1" makes pandas raise if the lookup side ever
# contains a repeated NAME, so the join can never silently duplicate rows.
vax = vax.merge(who_region_df, on='NAME', how='left', validate='m:1')

# Read metadata and keep only the join key plus the HPV introduction fields
metadata = pd.read_csv("vax_metadata.csv", index_col=0)  # Adjust path if necessary
metadata = metadata.rename(columns={'ISO_3_CODE': 'CODE'})
metadata_selected = metadata[['CODE', 'HPV_YEAR_INTRODUCTION', 'HPV_INT_DOSES']]

# Join vax with metadata_selected on the country code
vax = vax.merge(metadata_selected, on='CODE', how='left')

# Rows where the vaccine was not yet available: the year precedes the
# introduction year, or the programme is flagged "Not yet introduced".
# A missing introduction year compares as False, so those rows are left as is.
not_yet_introduced = (
    (vax['YEAR'] < vax['HPV_YEAR_INTRODUCTION'])
    | (vax['HPV_INT_DOSES'] == "Not yet introduced")
)

# Set coverage to 0 for those rows. Series.mask keeps this inside pandas;
# np.where would first convert the nullable Int64 column to a NumPy array whose
# dtype depends on the pandas version (object dtype on older releases).
# The explicit float64 cast turns missing values into NaN and gives the
# aggregation and plotting steps a plain numeric column.
vax['COVERAGE'] = vax['COVERAGE'].mask(not_yet_introduced, 0).astype('float64')

# Drop the HPV_YEAR_INTRODUCTION column (no longer needed after the update)
vax = vax.drop(columns=['HPV_YEAR_INTRODUCTION'])

# Turn the integer YEAR column into timestamps (1 January of each year)
year_as_int = vax["YEAR"].astype(int)
vax["YEAR"] = pd.to_datetime(year_as_int, format='%Y')


def _standard_error(values):
    """Standard error of the mean of the non-missing values; 0 when at most one is present."""
    n_obs = values.notna().sum()
    if n_obs <= 1:
        return 0
    return values.std(ddof=1) / np.sqrt(n_obs)


# Regional summary per year: mean coverage and its standard error
# (assumes COVERAGE already cleaned appropriately)
by_region_year = vax.groupby(["REGION", "YEAR"])
vax_name = by_region_year.agg(
    mean_coverage=("COVERAGE", "mean"),
    se_coverage=("COVERAGE", _standard_error),
).reset_index()

# Store REGION as a categorical and keep its categories for the plotting loop
vax_name["REGION"] = pd.Categorical(vax_name["REGION"])
regions = vax_name["REGION"].cat.categories


In [6]:
# One panel per region, laid out on a 2 x 3 grid (six regions)
panel_titles = [f"Region: {name}" for name in regions]
fig2 = make_subplots(rows=2, cols=3, subplot_titles=panel_titles)

# Forecast settings
FORECAST_END = pd.Timestamp("2030-01-01")  # last forecast period (year start)
N_SIMULATIONS = 1000                       # simulated paths behind the forecast interval
SIMULATION_SEED = 42                       # fixed seed so the forecast ribbon is reproducible


def _add_band(fig, x, upper, lower, fillcolor, row, col):
    """Draw a filled band between `upper` and `lower` on one subplot.

    Two traces are added: the upper bound, then the lower bound with
    fill="tonexty", which fills the area up to the previously added trace.
    Both bounding lines have width 0 so only the fill is visible.
    """
    fig.add_trace(
        go.Scatter(
            x=x,
            y=upper,
            mode="lines",
            line=dict(width=0),
            hoverinfo="skip",
            showlegend=False,
        ),
        row=row,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=x,
            y=lower,
            mode="lines",
            line=dict(width=0),
            fill="tonexty",
            fillcolor=fillcolor,
            hoverinfo="skip",
            showlegend=False,
        ),
        row=row,
        col=col,
    )


for idx, region in enumerate(regions):
    # Position of this region's panel on the 3-column grid (1-based)
    row = (idx // 3) + 1
    col = (idx % 3) + 1

    # Regional summary for this panel, indexed by year (already datetime)
    region_data = (
        vax_name[vax_name["REGION"] == region]
        .dropna(subset=["mean_coverage"])
        .sort_values("YEAR")
        .set_index("YEAR")
    )
    country_data = vax[vax["REGION"] == region]

    # Plot individual country lines (one thin trace per country);
    # sort=False keeps countries in order of first appearance
    for country, country_rows in country_data.groupby("NAME", sort=False):
        country_subset = country_rows.dropna(subset=["COVERAGE"])
        if not country_subset.empty:
            fig2.add_trace(
                go.Scatter(
                    x=country_subset["YEAR"],
                    y=country_subset["COVERAGE"],
                    mode="lines",
                    line=dict(color="red", width=0.5),
                    opacity=0.3,
                    # %{x|%Y} shows only the year instead of a full timestamp
                    hovertemplate=f"{country}<br>Year: %{{x|%Y}}<br>Coverage: %{{y}}%<extra></extra>",
                    showlegend=False,
                ),
                row=row,
                col=col,
            )

    # Ribbon: ± SE around the regional mean
    _add_band(
        fig2,
        region_data.index,
        region_data["mean_coverage"] + region_data["se_coverage"],
        region_data["mean_coverage"] - region_data["se_coverage"],
        "rgba(211,211,211,0.75)",
        row,
        col,
    )

    # Regional mean
    fig2.add_trace(
        go.Scatter(
            x=region_data.index,
            y=region_data["mean_coverage"],
            mode="lines",
            line=dict(color="black", width=2),
            hovertemplate="Regional Mean<br>Year: %{x|%Y}<br>Coverage: %{y:.1f}%<extra></extra>",
            showlegend=False,
        ),
        row=row,
        col=col,
    )

    # Forecast to 2030: needs at least two observations
    if len(region_data) < 2:
        continue

    last_year = region_data.index.max()
    forecast_years = pd.date_range(
        start=last_year + pd.DateOffset(years=1), end=FORECAST_END, freq='YS'
    )
    if len(forecast_years) == 0:
        # Observations already reach the forecast end; nothing to predict
        continue

    # Regular yearly series for the model. asfreq attaches the 'YS' frequency
    # and inserts any missing years as NaN (assigning index.freq directly
    # raises when a year is missing); interior gaps are filled linearly so the
    # model receives a complete series. Only the model input is interpolated,
    # the plotted regional mean is unchanged.
    ts = region_data["mean_coverage"].astype("float64").asfreq("YS").interpolate()

    try:
        model = ExponentialSmoothing(ts, trend="add", damped_trend=True).fit()
        forecast_mean = model.forecast(len(forecast_years))

        # Simulated future paths give the 5%-95% band; seeded for reproducibility
        sims = model.simulate(
            nsimulations=len(forecast_years),
            repetitions=N_SIMULATIONS,
            error="add",
            random_state=SIMULATION_SEED,
        )
        lower = sims.quantile(0.05, axis=1)
        upper = sims.quantile(0.95, axis=1)
    except Exception as e:
        # Best effort: a region whose model cannot be fitted keeps its
        # historical traces and simply has no forecast
        print(f"Forecast failed for region {region}: {e}")
        continue

    # Forecast interval ribbon (both bounds invisible, like the SE ribbon)
    _add_band(
        fig2,
        forecast_years,
        upper.values,
        lower.values,
        "rgba(150,150,150,0.5)",
        row,
        col,
    )

    # Forecast mean
    fig2.add_trace(
        go.Scatter(
            x=forecast_years,
            y=forecast_mean.values,
            mode="lines",
            line=dict(color="black", dash="dash", width=2),
            hovertemplate="Forecast<br>Year: %{x|%Y}<br>Coverage: %{y:.1f}%<extra></extra>",
            showlegend=False,
        ),
        row=row,
        col=col,
    )

# Figure-wide appearance: title, typeface and canvas size
fig2.update_layout(
    title="HPV Vaccine Coverage by Region",
    font={"family": "Times New Roman"},
    width=1200,
    height=800,
)

# Shared axis settings for every panel; the y axis is pinned to 0-100 %
fig2.update_yaxes(title_text="HPV Coverage (%)", range=[0, 100], autorange=False)
fig2.update_xaxes(title_text="Year", autorange=True)

# Caption below the panels explaining the visual encodings
caption_text = "Shaded lines: individual countries · Ribbon: ±SE · Solid lines: regional mean · Dashed: forecast"
fig2.add_annotation(
    text=caption_text,
    xref="paper",
    yref="paper",
    x=0.5,
    y=-0.05,
    showarrow=False,
    font={"size": 12},
)

fig2.show()