In [36]:
import os
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
from datetime import datetime, date, timedelta
from typing import Any, Literal
import matplotlib.pyplot as plt
import pickle
from collections.abc import Callable

In [None]:
def fit_model(
    dataset: pd.DataFrame,
    sigmoid_kind: Literal["logistic", "harvey"] = "logistic",
    n_samples: int = 2000,
    n_tune: int = 1000,
    top_n: int = 3,
    target_accept: float = 0.9,
    random_seed: int | None = 42,
) -> az.InferenceData:
    """Fit a Bayesian sigmoid growth model to benchmark scores.

    Each observed score is modelled as SkewNormal(mu, xi, alpha=s) with
    mu = l + (L - l) * sigmoid(k * (t - tau)), where l is the benchmark's `lower_bound`,
    L its upper asymptote (constrained to [0.75, 1]) and t the number of days since the
    benchmark's first kept release.

    If the dataset contains more than one benchmark, a joint model with shared hyperparameters is fitted.

    This allows benchmarks to inform each other through common priors on:
    - L_raw_mu, L_raw_sigma: (upper) asymptote distribution parameters (on the rescaled [0.75, 1] range)
    - k_mu, k_sigma: growth rate distribution parameters
    - xi_base_mu, xi_base_sigma: noise level distribution parameters
    - s_mu, s_sigma: skewness distribution parameters
    - alpha_raw_mu, alpha_raw_sigma: curve-shape parameters (only when sigmoid_kind="harvey")

    Args:
        dataset: A dataset containing benchmark data. It must contain the columns 'score', 'release_date', 'benchmark' and 'lower_bound' of types float, datetime, string, and float respectively.
        sigmoid_kind: Shape of the latent curve. "logistic" uses the standard logistic
            sigmoid; "harvey" uses (1 + (alpha - 1) * exp(-x)) ** (1 / (1 - alpha)) with a
            per-benchmark alpha > 1.
        n_samples: Number of MCMC samples to draw from the posterior distribution.
        n_tune: Number of tuning steps for the MCMC sampler.
        top_n: Number of top scores to consider when fitting the model. If top_n=1, only the frontier scores are used. Must be >= 1.
        target_accept: Target acceptance rate passed to `pm.sample`; increase it when the
            sampler reports divergences.
        random_seed: Seed passed to `pm.sample` (None for a non-reproducible run).

    Returns:
        An arviz InferenceData object containing the posterior samples.

    Raises:
        ValueError: If `sigmoid_kind` is unsupported, `top_n < 1`, a required column is
            missing or contains nulls, or the dataset has no rows.
        TypeError: If a required column has the wrong dtype.
    """
    # Validate arguments before any work so a typo does not surface mid model-building.
    if sigmoid_kind not in ("logistic", "harvey"):
        raise ValueError(f"Unsupported sigmoid type: {sigmoid_kind}")
    if top_n < 1:
        raise ValueError(f"top_n must be at least 1, got {top_n}.")

    # Check validity of the dataset
    required_columns_and_types = {
        "score": pd.api.types.is_float_dtype,
        "release_date": pd.api.types.is_datetime64_any_dtype,
        "benchmark": pd.api.types.is_string_dtype,
        "lower_bound": pd.api.types.is_float_dtype,
    }
    for column, check_type in required_columns_and_types.items():
        if column not in dataset.columns:
            raise ValueError(f"Dataset must contain the column '{column}'.")
        if not check_type(dataset[column]):
            raise TypeError(f"Column '{column}' must be of type {check_type.__name__}.")
        if dataset[column].isnull().any():
            raise ValueError(f"Column '{column}' must not contain null values.")
    if dataset.empty:
        raise ValueError("Dataset must contain at least one row.")

    # Keep rows whose score ranks within the top_n of all scores released so far for the
    # same benchmark (rank 1 = best; tied scores take the largest rank via method="max").
    # The index is reset first because `assign` aligns the expanding rank on index labels,
    # which breaks if the caller passes a frame with duplicate labels.
    dataset = (
        dataset.reset_index(drop=True)
        .sort_values(["benchmark", "release_date"])
        .assign(
            expanding_rank=lambda df: df.groupby("benchmark")["score"]
            .expanding()
            .rank(ascending=False, method="max")
            .reset_index(level=0, drop=True)
        )
        .loc[lambda df: df["expanding_rank"] <= top_n]
        .drop(columns=["expanding_rank"])
        .reset_index(drop=True)
    )

    # Prepare necessary columns for modeling: days since each benchmark's first kept
    # release, and half of each benchmark's observed span (used to centre tau's prior).
    dataset = dataset.assign(
        days=lambda df: (
            df["release_date"]
            - df.groupby("benchmark")["release_date"].transform("min")
        ).dt.days
    ).assign(
        days_mid=lambda df: (df.groupby("benchmark")["days"].transform("max") / 2.0)
    )

    # Encode benchmark names as indices for pymc coords
    # Use `benchmark_idx` to index the dataset within the model
    benchmark_idx, benchmark_names = pd.factorize(dataset["benchmark"], sort=True)
    dataset["benchmark_idx"] = benchmark_idx
    coords = {
        "benchmark": benchmark_names,
        "obs": np.arange(len(dataset)),
    }

    with pm.Model(coords=coords) as model:
        # Upper asymptote: a Beta on [0, 1] rescaled to [L_min, L_max]
        L_min = 0.75
        L_max = 1.0
        L_range = L_max - L_min
        L_raw_mu = pm.Beta(
            "L_raw_mu", mu=(0.96 - L_min) / L_range, sigma=0.02 / L_range
        )
        L_raw_sigma = pm.HalfNormal("L_raw_sigma", sigma=0.02 / L_range)
        L_raw = pm.Beta("L_raw", mu=L_raw_mu, sigma=L_raw_sigma, dims="benchmark")
        L = pm.Deterministic("L", L_min + L_range * L_raw, dims="benchmark")

        # Lower bound (fixed data, one value per benchmark)
        l = pm.Data(
            "l",
            dataset["lower_bound"].groupby(dataset["benchmark_idx"]).first().values,
            dims="benchmark",
        )

        # Inflection point, prior centred on the middle of each benchmark's observed span
        days_mid = dataset["days_mid"].groupby(dataset["benchmark_idx"]).first().values
        tau = pm.Gumbel("tau", mu=days_mid, beta=365 * 2, dims="benchmark")

        # Timestamps
        t_obs = pm.Data("t_obs", dataset["days"].values, dims="obs")
        idx_obs = pm.Data("idx_obs", dataset["benchmark_idx"].values, dims="obs")

        # Growth rate
        k_mu = pm.Gamma("k_mu", mu=0.005, sigma=0.002)
        k_sigma = pm.HalfNormal("k_sigma", sigma=0.005)
        k = pm.Gamma("k", mu=k_mu, sigma=k_sigma, dims="benchmark")

        # Mean latent performance
        logits = k[idx_obs] * (t_obs - tau[idx_obs])
        if sigmoid_kind == "logistic":
            sigmoid = pm.math.sigmoid(logits)
        else:  # "harvey" (validated at the top of the function)
            alpha_raw_mu = pm.Gamma("alpha_raw_mu", mu=1.5, sigma=0.5)
            alpha_raw_sigma = pm.HalfNormal("alpha_raw_sigma", sigma=0.5)
            alpha_raw = pm.Gamma(
                "alpha_raw", mu=alpha_raw_mu, sigma=alpha_raw_sigma, dims="benchmark"
            )
            alpha = pm.Deterministic("alpha", alpha_raw + 1.0, dims="benchmark")
            base = pm.math.maximum(
                1 - (1 - alpha[idx_obs]) * pm.math.exp(-logits), 1e-10
            )
            sigmoid = pm.math.exp(1 / (1 - alpha[idx_obs]) * pm.math.log(base))
        mu = l[idx_obs] + (L[idx_obs] - l[idx_obs]) * sigmoid

        # Noise: heteroskedastic, largest midway between l and L
        xi_base_mu = pm.Gamma("xi_base_mu", mu=0.05 + top_n / 50, sigma=0.02)
        xi_base_sigma = pm.HalfNormal("xi_base_sigma", sigma=0.05)
        xi_base = pm.Gamma(
            "xi_base", mu=xi_base_mu, sigma=xi_base_sigma, dims="benchmark"
        )
        # Floor the product before sqrt: once the sigmoid saturates, rounding in
        # l + (L - l) * sigmoid can push it to 0 or a tiny negative, giving a NaN value or
        # an infinite sqrt gradient (and hence divergences / NaN log-probabilities).
        variance_shape = pm.math.sqrt(
            pm.math.maximum((mu - l[idx_obs]) * (L[idx_obs] - mu), 1e-12)
        )
        max_variance = (L[idx_obs] - l[idx_obs]) / 2.0
        noise_factor = variance_shape / pm.math.maximum(max_variance, 1e-10)
        xi_0 = 0.01
        xi = xi_0 + xi_base[idx_obs] * noise_factor

        # Skewness (truncated at 0, so only left-skewed noise is allowed)
        s_mu = pm.Normal("s_mu", mu=-2 - top_n / 2, sigma=0.5)
        s_sigma = pm.HalfNormal("s_sigma", sigma=1)
        s = pm.TruncatedNormal("s", mu=s_mu, sigma=s_sigma, upper=0, dims="benchmark")

        # Observations
        y = pm.SkewNormal(
            "y",
            mu=mu,
            sigma=xi,
            alpha=s[idx_obs],
            observed=dataset["score"].values,
            dims="obs",
        )

        # Sample from the posterior
        idata = pm.sample(
            n_samples,
            tune=n_tune,
            return_inferencedata=True,
            random_seed=random_seed,
            target_accept=target_accept,
            init="adapt_diag",
            progressbar=True,
        )

    return idata


In [41]:
# Load the processed benchmark scores, coerce the columns `fit_model` validates to the
# dtypes it expects, drop incomplete rows, then fit the Harvey-sigmoid model on the top-3 scores.
_column_dtypes = {
    "benchmark": "string",
    "release_date": "datetime64[ns]",
    "score": "float64",
    "lower_bound": "float64",
}
raw_benchmarks = pd.read_csv(
    "benchmark_data_processed/all_normalized_updated_benchmarks.csv"
)
dataset = raw_benchmarks.astype(_column_dtypes).dropna(subset=list(_column_dtypes))
idata = fit_model(dataset, top_n=3, sigmoid_kind="harvey")

Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [L_raw_mu, L_raw_sigma, L_raw, tau, k_mu, k_sigma, k, alpha_raw_mu, alpha_raw_sigma, alpha_raw, xi_base_mu, xi_base_sigma, xi_base, s_mu, s_sigma, s]


Output()

Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 132 seconds.
There were 105 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details


# PyMC logistic forecast

In [None]:
# Forecast parameters (also bound as default arguments by the functions defined below)
n_samples = 2000  # MCMC draws per chain
n_tune = 1000  # MCMC tuning steps per chain
top_n = 3  # keep models ranked in the top-N at release (1 = frontier only)
forecast_days = 1523  # days to extrapolate past the last observation
forecast_type = "independent"  # Type of forecast model

# Create output directories for plots ("Images") and saved fits ("Fits")
for _output_dir in ("Images", "Fits"):
    os.makedirs(_output_dir, exist_ok=True)


def logistic(t, L, k, t0):
    """
    Unshifted logistic curve built from `pm.math`, so it can be used inside a PyMC model.

    Computes L / (1 + exp(-k * (t - t0))).

    Args:
        t: time (numeric)
        L: asymptote the curve approaches for large t when k > 0
        k: growth rate
        t0: inflection point (the curve equals L/2 at t = t0)
    """
    decay = pm.math.exp(k * (t0 - t))
    return L / (1 + decay)


def extract_frontier_improvements(df_long, top_n=1):
    """
    Extract frontier improvements from benchmark data.

    For top_n=1: keeps each row whose score strictly exceeds every earlier score
    (the running best-score frontier); every kept row has rank 1.
    For top_n>1: for each release date d, keeps each row released on d whose model is
    ranked within the top N at d. A model's score at d is its best score on or before d;
    ranks are dense over those scores (tied scores share a rank).

    Args:
        df_long: DataFrame with columns [date, model_id, avg_score]
        top_n: number of top models to track

    Returns:
        List of dicts with keys: date, score, model_id, rank (in date order)
    """
    df_sorted = df_long.sort_values("date")
    frontier_improvements = []

    if top_n == 1:
        # Original frontier logic (best score over time)
        best_score = float("-inf")

        for _, row in df_sorted.iterrows():
            if row["avg_score"] > best_score:
                best_score = row["avg_score"]
                frontier_improvements.append(
                    {
                        "date": row["date"],
                        "score": row["avg_score"],
                        "model_id": row["model_id"],
                        "rank": 1,
                    }
                )
        return frontier_improvements

    # Track models that were in top-N at their release date
    for release_date in df_sorted["date"].unique():
        history = df_sorted[df_sorted["date"] <= release_date]

        # Dense rank of each model's best score so far. Only the score determines the
        # rank, so no tie-breaking on date or model_id is needed.
        rank_by_model = (
            history.groupby("model_id")["avg_score"]
            .max()
            .rank(method="dense", ascending=False)
            .astype(int)
        )

        released_today = df_sorted[df_sorted["date"] == release_date]
        for _, row in released_today.iterrows():
            # Missing model ids are dropped by groupby, so `.get` returns None and the
            # row is skipped.
            rank = rank_by_model.get(row["model_id"])
            if rank is not None and rank <= top_n:
                frontier_improvements.append(
                    {
                        "date": row["date"],
                        "score": row["avg_score"],
                        "model_id": row["model_id"],
                        "rank": rank,
                    }
                )

    return frontier_improvements


def fit_logistic_benchmark(
    df_long,
    task_name,
    n_samples=n_samples,
    n_tune=n_tune,
    top_n=top_n,
    forecast_type=forecast_type,
    save_dir="Fits",
    lower_bounds_dict=None,
):
    """
    Fit a shifted logistic curve to one benchmark and save the results to `save_dir`.

    The mean curve is lower_bound + (L - lower_bound) / (1 + exp(-k * (t - t0))), with t
    in days since the first kept observation. Observations are skew-normal with a noise
    scale that peaks midway between lower_bound and L.

    Args:
        df_long: DataFrame with columns [date, model_id, avg_score]
        task_name: name of the benchmark (for plotting and output file names)
        n_samples: number of MCMC samples
        n_tune: number of tuning steps
        top_n: number of top models to track (1 = frontier only)
        forecast_type: type of forecast model (e.g., "independent")
        save_dir: directory to save inference data (created if missing)
        lower_bounds_dict: dict mapping benchmark names to lower bounds (0-1 scale)

    Returns:
        idata: InferenceData object with posterior samples
        frontier_df: DataFrame with frontier/top-N improvements
        model: PyMC model object

    Raises:
        ValueError: if no frontier/top-N observations can be extracted from `df_long`.
    """
    # Extract frontier improvements
    frontier_improvements = extract_frontier_improvements(df_long, top_n=top_n)
    if not frontier_improvements:
        raise ValueError(f"No frontier observations extracted for {task_name}.")
    frontier_df = pd.DataFrame(frontier_improvements)

    # Convert dates to numeric (days since first observation)
    frontier_df["date_dt"] = pd.to_datetime(frontier_df["date"])
    min_date = frontier_df["date_dt"].min()
    frontier_df["days"] = (frontier_df["date_dt"] - min_date).dt.days

    # Prepare data for PyMC
    t_obs = frontier_df["days"].values.astype(float)
    y_obs = frontier_df["score"].values.astype(float)

    # Get lower bound for this benchmark (default to 0 if not available)
    if lower_bounds_dict is not None and task_name in lower_bounds_dict:
        lower_bound = lower_bounds_dict[task_name]
        # Handle NaN values
        if pd.isna(lower_bound):
            lower_bound = 0.0
            print(f"  Lower bound for {task_name}: NA (using 0.0)")
        else:
            print(f"  Lower bound for {task_name}: {lower_bound:.1%}")
    else:
        lower_bound = 0.0
        print(f"  Lower bound for {task_name}: not found (using 0.0)")

    # Build PyMC model with SHIFTED LOGISTIC
    with pm.Model() as model:
        # L: upper asymptote, modelled as a Beta on [0, 1] rescaled to [L_min, 1]
        L_min = 0.75
        available_range = 1.0 - L_min

        # Target upper asymptote around 0.95 in absolute terms
        target_L_abs = 0.95

        # Mean and sd on the raw (0–1) scale
        L_raw_mu = (target_L_abs - L_min) / available_range
        L_raw_sigma = 0.03 / available_range

        L_raw = pm.Beta("L_raw", mu=L_raw_mu, sigma=L_raw_sigma)
        L = pm.Deterministic("L", L_min + available_range * L_raw)

        # k: growth rate
        k = pm.Gamma("k", mu=0.01, sigma=0.008)

        # t0: inflection point, prior centred on the middle of the observed span
        t_mid = (t_obs.min() + t_obs.max()) / 2
        t0 = pm.Gumbel("t0", mu=t_mid, beta=365 * 2)

        # Expected value: SHIFTED LOGISTIC
        # y = lower_bound + (L - lower_bound) / (1 + exp(-k * (t - t0)))
        logistic_01 = 1.0 / (1 + pm.math.exp(-k * (t_obs - t0)))
        mu = lower_bound + (L - lower_bound) * logistic_01

        # Heteroskedastic noise model: Beta-like pattern
        # Base noise level (more noise for higher top_n)
        xi_base = pm.Gamma("xi_base", mu=0.05 + top_n / 100, sigma=0.035 + top_n / 200)

        # Beta-like variance pattern: √((μ - lower_bound) × (L - μ)) over [lower_bound, L].
        # The product is floored before sqrt: once the logistic saturates, rounding can make
        # it 0 or a tiny negative, which gives a NaN value or an infinite sqrt gradient.
        variance_shape = pm.math.sqrt(
            pm.math.maximum((mu - lower_bound) * (L - mu), 1e-12)
        )

        # Normalized noise factor (peaks at 1.0 at inflection point)
        # Maximum variance is at midpoint: (L - lower_bound) / 2
        max_variance = (L - lower_bound) / 2.0
        noise_factor = variance_shape / pm.math.maximum(
            max_variance, 1e-10
        )  # Avoid division by zero

        # Skewness parameter (more asymmetry for higher top_n); truncated at 0 => left skew
        s = pm.TruncatedNormal(
            "s", mu=-1.5 - top_n / 2, sigma=0.375 + top_n / 8, upper=0
        )

        # Likelihood
        y = pm.SkewNormal(
            "y", mu=mu, sigma=0.01 + xi_base * noise_factor, alpha=s, observed=y_obs
        )

        # Sample
        idata = pm.sample(
            n_samples,
            tune=n_tune,
            return_inferencedata=True,
            random_seed=42,
            target_accept=0.9,
            progressbar=False,
        )

    # Save results. Create the directory so a custom save_dir works on first use
    # ("" means the current directory, which needs no creation).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    safe_task_name = task_name.replace("/", "_").replace(" ", "_").replace(".", "_")
    file_stem = os.path.join(save_dir, f"{safe_task_name}_{forecast_type}_top{top_n}")

    def _remove_existing(path):
        # Delete an existing file first so the new file's name casing is used on
        # case-insensitive filesystems (e.g. the macOS default).
        if os.path.exists(path):
            os.remove(path)

    # Save inference data (NetCDF format - can be loaded with az.from_netcdf())
    idata_path = f"{file_stem}_idata.nc"
    _remove_existing(idata_path)
    idata.to_netcdf(idata_path)
    print(f"Saved inference data to: {idata_path}")

    # Save frontier data (CSV)
    frontier_csv_path = f"{file_stem}_frontier.csv"
    _remove_existing(frontier_csv_path)
    frontier_df.to_csv(frontier_csv_path, index=False)
    print(f"Saved frontier data to: {frontier_csv_path}")

    # Save metadata object (pickle) - excludes model which can't be pickled
    metadata = {
        "task_name": task_name,
        "top_n": top_n,
        "forecast_type": forecast_type,
        "n_samples": n_samples,
        "n_tune": n_tune,
        "idata_path": idata_path,
        "frontier_csv_path": frontier_csv_path,
        "heteroskedastic": True,
        "noise_model": "beta_like",
        "lower_bound": lower_bound,
    }
    metadata_path = f"{file_stem}_metadata.pkl"
    _remove_existing(metadata_path)
    with open(metadata_path, "wb") as f:
        pickle.dump(metadata, f)
    print(f"Saved metadata to: {metadata_path}")

    return idata, frontier_df, model


def plot_logistic_fit(
    idata,
    frontier_df,
    task_name,
    forecast_days=forecast_days,
    top_n=top_n,
    forecast_type=forecast_type,
    save_dir="Images",
    lower_bound=0.0,
    random_seed=None,
):
    """
    Plot the shifted logistic fit with uncertainty bands, save it, and print a summary.

    For each posterior draw the curve is
    lower_bound + (L - lower_bound) / (1 + exp(-k * (t - t0))), matching the model fitted
    by `fit_logistic_benchmark`; milestone ETAs invert that same shifted curve.

    Args:
        idata: InferenceData from PyMC sampling (posterior variables L, k, t0, xi_base, s)
        frontier_df: DataFrame with frontier data; uses 'date', 'days', 'score',
            'model_id' and, when top_n > 1, 'rank'
        task_name: name of the benchmark
        forecast_days: number of days to forecast beyond last observation
        top_n: number of top models tracked (for title)
        forecast_type: type of forecast model (for filename)
        save_dir: directory to save plots
        lower_bound: lower asymptote (random chance baseline)
        random_seed: optional seed for the observation-noise draws behind the prediction
            interval; None uses NumPy's global random state.
    """
    import matplotlib.dates as mdates
    from scipy.stats import skewnorm

    # Extract posterior samples (chains and draws flattened together)
    posterior = idata.posterior
    L_samples = posterior["L"].values.flatten()
    k_samples = posterior["k"].values.flatten()
    t0_samples = posterior["t0"].values.flatten()
    xi_base_samples = posterior["xi_base"].values.flatten()
    s_samples = posterior["s"].values.flatten()

    # Prepare time grid for plotting (observed + forecast)
    t_obs = frontier_df["days"].values
    t_grid = np.linspace(t_obs.min(), t_obs.max() + forecast_days, 200)

    # Posterior curves, vectorised: rows = posterior draws, columns = grid points.
    L_col = L_samples[:, None]
    logistic_01 = 1.0 / (1 + np.exp(-k_samples[:, None] * (t_grid - t0_samples[:, None])))
    predictions = lower_bound + (L_col - lower_bound) * logistic_01

    # Heteroskedastic sigma, same form as the model. Clip the product at 0: with a non-zero
    # lower bound, lower_bound + (L - lower_bound) * 1.0 can land one ulp above L once the
    # logistic saturates, and sqrt of that tiny negative is NaN (skewnorm then rejects it).
    variance_shape = np.sqrt(
        np.maximum((predictions - lower_bound) * (L_col - predictions), 0.0)
    )
    max_variance = (L_col - lower_bound) / 2.0
    noise_factor = variance_shape / np.maximum(max_variance, 1e-10)  # Avoid division by zero
    xi_grid = 0.01 + xi_base_samples[:, None] * noise_factor

    # Calculate percentiles for logistic curve
    median_pred = np.percentile(predictions, 50, axis=0)
    lower_50_logistic = np.percentile(predictions, 25, axis=0)
    upper_50_logistic = np.percentile(predictions, 75, axis=0)
    lower_95_logistic = np.percentile(predictions, 2.5, axis=0)
    upper_95_logistic = np.percentile(predictions, 97.5, axis=0)

    # Prediction intervals with observation noise: one draw per posterior sample per grid point
    rng = None if random_seed is None else np.random.default_rng(random_seed)
    noisy_draws = skewnorm.rvs(
        s_samples[:, None], loc=predictions, scale=xi_grid, random_state=rng
    )
    noisy_draws = np.clip(noisy_draws, 0, 1)
    lower_95_sampling = np.percentile(noisy_draws, 2.5, axis=0)
    upper_95_sampling = np.percentile(noisy_draws, 97.5, axis=0)

    # Convert days back to dates for x-axis
    min_date = frontier_df["date"].min()
    dates_grid = [min_date + timedelta(days=int(d)) for d in t_grid]

    # Plot
    fig, ax = plt.subplots(figsize=(12, 6))

    # Add horizontal line for lower bound if > 0
    if lower_bound > 0.01:
        ax.axhline(
            lower_bound,
            color="gray",
            linestyle=":",
            alpha=0.5,
            linewidth=1.5,
            label=f"Random chance ({lower_bound:.1%})",
        )

    # Predicted sampling intervals (wider, includes observation noise)
    ax.fill_between(
        dates_grid,
        lower_95_sampling,
        upper_95_sampling,
        alpha=0.1,
        color="#F18F01",
        label="95% prediction interval",
    )

    # Logistic uncertainty bands (narrower, just curve uncertainty)
    ax.fill_between(
        dates_grid,
        lower_95_logistic,
        upper_95_logistic,
        alpha=0.2,
        color="#2E86AB",
        label="95% CI (logistic)",
    )
    ax.fill_between(
        dates_grid,
        lower_50_logistic,
        upper_50_logistic,
        alpha=0.3,
        color="#2E86AB",
        label="50% CI (logistic)",
    )

    # Median prediction
    ax.plot(
        dates_grid,
        median_pred,
        "-",
        linewidth=2,
        color="#2E86AB",
        label="Median prediction",
    )

    # Observed points - color by rank if top_n > 1
    if top_n == 1:
        ax.plot(
            frontier_df["date"],
            frontier_df["score"],
            "o",
            markersize=8,
            color="#A23B72",
            label="Observed frontier",
            zorder=10,
        )
    else:
        # Plot different ranks with different colors/sizes
        colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, top_n))
        for rank in range(1, top_n + 1):
            rank_data = frontier_df[frontier_df["rank"] == rank]
            if not rank_data.empty:
                ax.plot(
                    rank_data["date"],
                    rank_data["score"],
                    "o",
                    markersize=10 - (rank - 1) * 1.0,
                    color=colors[rank - 1],
                    label=f"Top-{rank} at release",
                    alpha=0.8,
                    zorder=10 - rank,
                )

    # Fix the y-range before placing the "Last obs" label so the label is positioned
    # relative to the final axis limits.
    ax.set_ylim(0, 1.05)

    # Vertical line at last observation
    last_date = frontier_df["date"].max()
    ax.axvline(last_date, color="gray", linestyle="--", alpha=0.5, linewidth=1)
    ax.text(
        last_date,
        ax.get_ylim()[0] + 0.05,
        "Last obs",
        rotation=90,
        verticalalignment="bottom",
        fontsize=9,
        color="gray",
    )

    # Formatting
    ax.set_xlabel("Date", fontsize=12, fontweight="bold")
    ax.set_ylabel("Score", fontsize=12, fontweight="bold")

    title_suffix = " (Frontier)" if top_n == 1 else f" (Top-{top_n} at release)"
    ax.set_title(
        f"Logistic Growth Forecast: {task_name}{title_suffix}",
        fontsize=14,
        fontweight="bold",
    )
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=9)

    # Format x-axis
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
    ax.xaxis.set_major_locator(mdates.YearLocator())
    plt.xticks(rotation=45, ha="right")

    plt.tight_layout()

    # Save plot
    safe_task_name = task_name.replace("/", "_").replace(" ", "_").replace(".", "_")
    plot_path = os.path.join(
        save_dir, f"{safe_task_name}_{forecast_type}_top{top_n}_forecast.png"
    )
    if os.path.exists(
        plot_path
    ):  # Remove existing file to ensure correct case on macOS
        os.remove(plot_path)
    plt.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"Saved plot to: {plot_path}")

    plt.show()
    # Release the figure; this function is called once per benchmark inside loops.
    plt.close(fig)

    # Print summary statistics
    print(f"\n{'=' * 80}")
    title_text = (
        f"FORECAST SUMMARY: {task_name} (Top-{top_n} at release)"
        if top_n > 1
        else f"FORECAST SUMMARY: {task_name} (Frontier)"
    )
    print(title_text)
    print(f"{'=' * 80}")
    print(f"Data points: {len(frontier_df)}")
    print(f"Unique models: {frontier_df['model_id'].nunique()}")
    print(f"Lower asymptote (random chance): {lower_bound:.3f}")
    print(f"Current best score: {frontier_df['score'].max():.3f}")
    print(
        f"Predicted upper asymptote (L): {np.median(L_samples):.3f} [{np.percentile(L_samples, 2.5):.3f}, {np.percentile(L_samples, 97.5):.3f}]"
    )
    print(
        f"Growth rate (k): {np.median(k_samples):.4f} [{np.percentile(k_samples, 2.5):.4f}, {np.percentile(k_samples, 97.5):.4f}]"
    )
    print(f"Inflection point (t0): {np.median(t0_samples):.0f} days")

    # Time to reach milestones
    current_score = frontier_df["score"].max()
    current_days = frontier_df["days"].max()

    for target in [0.9, 0.95, 0.99]:
        if target > current_score:
            # Invert the SHIFTED logistic for each posterior draw:
            #   target = lb + (L - lb) * p,  p = 1 / (1 + exp(-k * (t - t0)))
            #   => t = t0 - log(1/p - 1) / k, defined only for 0 < p < 1.
            # Using p = target / L instead would ignore the lower bound and give wrong
            # times whenever lower_bound > 0.
            with np.errstate(divide="ignore", invalid="ignore"):
                p = (target - lower_bound) / (L_samples - lower_bound)
            reachable = (p > 0) & (p < 1)
            t_target = (
                t0_samples[reachable]
                - np.log(1 / p[reachable] - 1) / k_samples[reachable]
            )
            days_from_now = t_target - current_days
            days_to_target = days_from_now[days_from_now > 0]

            if days_to_target.size:
                median_days = np.median(days_to_target)
                lower_days = np.percentile(days_to_target, 2.5)
                upper_days = np.percentile(days_to_target, 97.5)
                target_date = last_date + timedelta(days=int(median_days))
                print(
                    f"\nTime to reach {target:.0%}: {median_days / 365:.1f} years "
                    + f"[{lower_days / 365:.1f}, {upper_days / 365:.1f}] (ETA: {target_date.strftime('%Y-%m')})"
                )


def load_fit(task_name, forecast_type=forecast_type, top_n=top_n, load_dir="Fits"):
    """
    Reload the artifacts written by `fit_logistic_benchmark` for one benchmark.

    Args:
        task_name: name of the benchmark (sanitised the same way as when saving)
        forecast_type: type of forecast model
        top_n: number of top models
        load_dir: directory containing saved fits

    Returns:
        dict with keys "idata", "frontier_df", "metadata", "task_name", "top_n" and
        "forecast_type".
    """
    # Build the shared file stem: "/", " " and "." in the task name become "_".
    safe_name = task_name
    for unsafe_char in ("/", " ", "."):
        safe_name = safe_name.replace(unsafe_char, "_")
    stem = os.path.join(load_dir, f"{safe_name}_{forecast_type}_top{top_n}")

    idata_path = f"{stem}_idata.nc"
    frontier_csv_path = f"{stem}_frontier.csv"
    metadata_path = f"{stem}_metadata.pkl"

    idata = az.from_netcdf(idata_path)

    # The CSV stores dates as text; turn 'date' back into datetime.date values and,
    # when present, 'date_dt' back into timestamps.
    frontier_df = pd.read_csv(frontier_csv_path)
    frontier_df["date"] = pd.to_datetime(frontier_df["date"]).dt.date
    if "date_dt" in frontier_df.columns:
        frontier_df["date_dt"] = pd.to_datetime(frontier_df["date_dt"])

    # NOTE: pickle.load can execute arbitrary code; only load metadata files this
    # notebook wrote itself.
    with open(metadata_path, "rb") as handle:
        metadata = pickle.load(handle)

    print(f"Loaded fit for: {task_name}")
    for label, path in (
        ("Inference data", idata_path),
        ("Frontier data", frontier_csv_path),
        ("Metadata", metadata_path),
    ):
        print(f"  - {label}: {path}")

    return {
        "idata": idata,
        "frontier_df": frontier_df,
        "metadata": metadata,
        "task_name": task_name,
        "top_n": top_n,
        "forecast_type": forecast_type,
    }

## Independent forecasts

### Internal benchmarks

In [None]:
# Run fit for all internal benchmarks.
# NOTE(review): `load_all_benchmark_matrices` and `lower_bounds_dict` are not defined
# anywhere in this notebook; they must come from a cell or import not shown here.
# Confirm before relying on Restart & Run All.
internal_matrices, external_matrices = load_all_benchmark_matrices()

print(f"Found {len(internal_matrices)} internal benchmarks")
print(f"Found {len(external_matrices)} external benchmarks")

# Successful fits (internal here; the external cell adds to the same dict).
results = {}
separator = "=" * 80

for task_name, data in internal_matrices.items():
    print(f"\n{separator}\nFitting: {task_name}\n{separator}")

    df_long = data["df_long"]
    n_obs = len(df_long)
    if n_obs < 3:
        print(f"Skipping {task_name}: insufficient data ({n_obs} observations)")
        continue

    try:
        idata, frontier_df, model = fit_logistic_benchmark(
            df_long, task_name, lower_bounds_dict=lower_bounds_dict
        )

        # Lower bound for plotting: the dict entry when present and not NA, else 0.0.
        lower_bound = 0.0
        if lower_bounds_dict and task_name in lower_bounds_dict:
            lb = lower_bounds_dict[task_name]
            lower_bound = lower_bound if pd.isna(lb) else lb

        results[task_name] = dict(idata=idata, frontier_df=frontier_df, model=model)

        plot_logistic_fit(idata, frontier_df, task_name, lower_bound=lower_bound)

        print("\nMCMC Diagnostics:")
        print(az.summary(idata, var_names=["L", "k", "t0", "xi_base", "s"]))

    except Exception as e:
        # Best-effort batch run: report the failure and continue with the next benchmark.
        print(f"Error fitting {task_name}: {e}")
        import traceback

        traceback.print_exc()

### External benchmarks

In [None]:
# Fit independent logistic models to every external benchmark, adding successful fits to
# the `results` dict created in the internal-benchmark cell.
# `results` already holds the internal fits, so external successes are counted separately
# for the summary line.
n_external_fitted = 0

for task_name, data in external_matrices.items():
    print(f"\n{'=' * 80}")
    print(f"Fitting: {task_name}")
    print(f"{'=' * 80}")

    df_long = data["df_long"]

    if len(df_long) < 3:
        print(f"Skipping {task_name}: insufficient data ({len(df_long)} observations)")
        continue

    try:
        idata, frontier_df, model = fit_logistic_benchmark(
            df_long, task_name, lower_bounds_dict=lower_bounds_dict
        )

        # Get lower bound for plotting (0.0 when missing or NA)
        lower_bound = 0.0
        if lower_bounds_dict and task_name in lower_bounds_dict:
            lb = lower_bounds_dict[task_name]
            if not pd.isna(lb):
                lower_bound = lb

        results[task_name] = {
            "idata": idata,
            "frontier_df": frontier_df,
            "model": model,
        }
        # Count as fitted once sampling succeeded, even if plotting below fails.
        n_external_fitted += 1

        # Plot
        plot_logistic_fit(idata, frontier_df, task_name, lower_bound=lower_bound)

        # Print diagnostics
        print("\nMCMC Diagnostics:")
        print(az.summary(idata, var_names=["L", "k", "t0", "xi_base", "s"]))

    except Exception as e:
        # Best-effort batch run: report the failure and continue with the next benchmark.
        print(f"Error fitting {task_name}: {e}")
        import traceback

        traceback.print_exc()

print(f"\n{'=' * 80}")
print(f"Successfully fit {n_external_fitted} / {len(external_matrices)} external benchmarks")
print(f"{'=' * 80}")

## Joint hyperparameters

In [None]:
# ============================================================================
# JOINT HYPERPARAMETER MODEL
# ============================================================================


def fit_logistic_joint_hyperparameters(
    benchmark_dict,
    n_samples=2000,
    n_tune=1000,
    top_n=3,
    forecast_type="joint_hyperparameters",
    save_dir="Fits",
    lower_bounds_dict=None,
):
    """
    Fit shifted logistic curves to multiple benchmarks with shared hyperparameters.

    Each benchmark gets its own curve
        mu(t) = lower_bound + (L - lower_bound) / (1 + exp(-k * (t - t0)))
    with skew-normal, heteroskedastic observation noise. The per-benchmark
    parameters are drawn from common priors, so benchmarks inform each other
    through:
    - L_mu, L_sigma: asymptote distribution parameters (L is constrained to [0.75, 1])
    - k_mu, k_sigma: growth rate distribution parameters
    - xi_base_mu, xi_base_sigma: noise level distribution parameters
    - s_mu, s_sigma: skewness distribution parameters

    Args:
        benchmark_dict: dict of {task_name: df_long} for multiple benchmarks;
            must contain at least one benchmark
        n_samples: number of MCMC draws passed to pm.sample
        n_tune: number of tuning steps
        top_n: number of top models to track (fixed, not learned)
        forecast_type: type of forecast model (used in output file names)
        save_dir: directory to save inference data, frontier CSVs and metadata;
            created if it does not exist
        lower_bounds_dict: dict mapping benchmark names to lower bounds (0-1 scale);
            missing or NaN entries use 0.0

    Returns:
        idata: InferenceData with posterior samples for all benchmarks
        frontier_dfs: dict of {task_name: frontier_df}, in the same task order
            as the model's per-task parameters
        model: PyMC model object

    Raises:
        ValueError: if benchmark_dict is empty, or a benchmark yields no
            frontier observations.
    """
    if not benchmark_dict:
        raise ValueError("benchmark_dict must contain at least one benchmark.")

    # Create the output directory up front, so a missing folder does not
    # surface only after the (expensive) sampling step.
    os.makedirs(save_dir, exist_ok=True)

    # Prepare data for all benchmarks
    task_names = []
    all_t_obs = []
    all_y_obs = []
    all_lower_bounds = []
    frontier_dfs = {}

    for task_name, df_long in benchmark_dict.items():
        # Extract frontier (same logic as independent model)
        frontier_improvements = extract_frontier_improvements(df_long, top_n=top_n)

        frontier_df = pd.DataFrame(frontier_improvements)
        if frontier_df.empty:
            raise ValueError(
                f"No frontier observations extracted for benchmark '{task_name}'."
            )
        frontier_df["date_dt"] = pd.to_datetime(frontier_df["date"])
        min_date = frontier_df["date_dt"].min()
        # Time axis: days since this benchmark's first frontier observation
        frontier_df["days"] = (frontier_df["date_dt"] - min_date).dt.days

        # Get lower bound for this benchmark
        if lower_bounds_dict is not None and task_name in lower_bounds_dict:
            lower_bound = lower_bounds_dict[task_name]
            if pd.isna(lower_bound):
                lower_bound = 0.0
                print(f"  Lower bound for {task_name}: NA (using 0.0)")
            else:
                print(f"  Lower bound for {task_name}: {lower_bound:.1%}")
        else:
            lower_bound = 0.0
            print(f"  Lower bound for {task_name}: not found (using 0.0)")

        # Store frontier data (insertion order defines the task index)
        frontier_dfs[task_name] = frontier_df
        task_names.append(task_name)
        all_lower_bounds.append(lower_bound)

        # Prepare PyMC data
        all_t_obs.append(frontier_df["days"].values.astype(float))
        all_y_obs.append(frontier_df["score"].values.astype(float))

    n_tasks = len(task_names)
    all_lower_bounds = np.array(all_lower_bounds, dtype=float)

    # ========================================================================
    # VECTORIZATION: Concatenate all observations with task indices
    # ========================================================================

    # task_indices[i] is the benchmark index of observation i
    obs_counts = [len(t) for t in all_t_obs]
    task_indices = np.repeat(np.arange(n_tasks, dtype=np.int32), obs_counts)

    # Concatenate all observations
    t_obs_all = np.concatenate(all_t_obs)
    y_obs_all = np.concatenate(all_y_obs)
    n_obs_total = len(t_obs_all)

    # Lower bounds for each observation (indexed by task)
    lower_bounds_per_obs = all_lower_bounds[task_indices]

    # Centre of each benchmark's observed time window: location of its t0 prior
    t_mids = np.array([(t.min() + t.max()) / 2 for t in all_t_obs])

    print(f"\n{'=' * 80}")
    print("VECTORIZED MODEL SETUP")
    print(f"{'=' * 80}")
    print(f"Number of tasks: {n_tasks}")
    print(f"Total observations: {n_obs_total}")
    print(f"Observations per task: {obs_counts}")

    # Build joint hierarchical model with SHIFTED LOGISTIC (VECTORIZED)
    with pm.Model() as model:
        # ====================================================================
        # HYPERPRIORS (shared across all benchmarks)
        # ====================================================================

        # Asymptote (L) hyperparameters (defined on raw 0–1 scale, then shifted)
        L_min = 0.75  # Minimum allowed asymptote (absolute value)
        available_range = 1.0 - L_min

        # Hyperprior for the population mean and sd of L on the raw scale
        L_raw_mu = pm.Beta(
            "L_raw_mu",
            mu=(0.96 - L_min)
            / available_range,  # centers L_mu around 0.96 on absolute scale
            sigma=0.02 / available_range,  # corresponding sd on raw scale
        )
        L_raw_sigma = pm.HalfNormal("L_raw_sigma", sigma=0.03 / available_range)

        # Implied absolute-scale hyperparameters (for diagnostics)
        L_mu = pm.Deterministic("L_mu", L_min + available_range * L_raw_mu)
        L_sigma = pm.Deterministic("L_sigma", available_range * L_raw_sigma)

        # Growth rate (k) hyperparameters
        k_mu = pm.Gamma("k_mu", mu=0.005, sigma=0.002)
        k_sigma = pm.HalfNormal("k_sigma", sigma=0.005)

        # Base noise (xi_base) hyperparameters
        xi_base_mu = pm.Gamma("xi_base_mu", mu=0.05 + top_n / 50, sigma=0.02)
        xi_base_sigma = pm.HalfNormal("xi_base_sigma", sigma=0.05)

        # Skewness (s) hyperparameters
        s_mu = pm.Normal("s_mu", mu=-2 - top_n / 2, sigma=0.5)
        s_sigma = pm.HalfNormal("s_sigma", sigma=1)

        # ====================================================================
        # TASK-SPECIFIC PARAMETERS (informed by hyperpriors)
        # ====================================================================

        # Task-specific upper asymptotes (using shifted Beta on raw scale)
        # NOTE(review): PyMC's Beta(mu, sigma) needs sigma**2 < mu * (1 - mu);
        # large L_raw_sigma draws give an invalid Beta (-inf logp) — confirm
        # this implicit constraint is intended.
        L_raw = pm.Beta("L_raw", mu=L_raw_mu, sigma=L_raw_sigma, shape=n_tasks)
        L = pm.Deterministic("L", L_min + available_range * L_raw)

        # Task-specific growth rates
        k = pm.Gamma("k", mu=k_mu, sigma=k_sigma, shape=n_tasks)

        # Task-specific inflection points (t0)
        t0 = pm.Gumbel("t0", mu=t_mids, beta=365 * 2, shape=n_tasks)

        # Task-specific base noise
        xi_base = pm.Gamma("xi_base", mu=xi_base_mu, sigma=xi_base_sigma, shape=n_tasks)

        # Task-specific skewness (non-positive: left-skewed noise)
        s = pm.TruncatedNormal("s", mu=s_mu, sigma=s_sigma, upper=0, shape=n_tasks)

        # ====================================================================
        # VECTORIZED LIKELIHOOD
        # ====================================================================

        # Index task-specific parameters by observation
        L_obs = L[task_indices]
        k_obs = k[task_indices]
        t0_obs = t0[task_indices]
        xi_base_obs = xi_base[task_indices]
        s_obs = s[task_indices]
        lower_bounds_obs = pm.math.constant(lower_bounds_per_obs)

        # Shifted logistic curve (vectorized over all observations)
        logistic_01 = 1.0 / (1 + pm.math.exp(-k_obs * (t_obs_all - t0_obs)))
        mu_obs = lower_bounds_obs + (L_obs - lower_bounds_obs) * logistic_01

        # Heteroskedastic noise (Beta-like pattern for shifted range): largest
        # at the curve's midpoint, shrinking towards both asymptotes.
        variance_shape = pm.math.sqrt((mu_obs - lower_bounds_obs) * (L_obs - mu_obs))
        max_variance = (L_obs - lower_bounds_obs) / 2.0
        noise_factor = variance_shape / pm.math.maximum(
            max_variance, 1e-10
        )  # Avoid division by zero
        xi = 0.01 + xi_base_obs * noise_factor

        # Single vectorized likelihood for ALL observations
        y = pm.SkewNormal("y", mu=mu_obs, sigma=xi, alpha=s_obs, observed=y_obs_all)

        # ====================================================================
        # SAMPLING
        # ====================================================================

        print(f"Fitting vectorized joint model with {n_tasks} benchmarks...")
        print(f"Total observations: {n_obs_total}")

        idata = pm.sample(
            n_samples,
            tune=n_tune,
            return_inferencedata=True,
            random_seed=42,
            target_accept=0.9,
            init="adapt_diag",
            progressbar=True,
        )

    # ====================================================================
    # SAVE RESULTS
    # ====================================================================

    # Save joint inference data
    idata_path = os.path.join(save_dir, f"joint_{forecast_type}_top{top_n}_idata.nc")
    idata.to_netcdf(idata_path)
    print(f"\nSaved joint inference data to: {idata_path}")

    # Save individual frontier CSVs (existing files are removed first to
    # ensure the file name's case is updated on case-insensitive file systems)
    for task_name, frontier_df in frontier_dfs.items():
        safe_task_name = task_name.replace("/", "_").replace(" ", "_").replace(".", "_")
        frontier_csv_path = os.path.join(
            save_dir, f"{safe_task_name}_{forecast_type}_top{top_n}_frontier.csv"
        )
        if os.path.exists(frontier_csv_path):
            os.remove(frontier_csv_path)
        frontier_df.to_csv(frontier_csv_path, index=False)

    # Save metadata
    metadata = {
        "task_names": task_names,
        "n_tasks": n_tasks,
        "top_n": top_n,
        "forecast_type": forecast_type,
        "n_samples": n_samples,
        "n_tune": n_tune,
        "idata_path": idata_path,
        "heteroskedastic": True,
        "noise_model": "beta_like",
        "hierarchical": True,
        "vectorized": True,
        "lower_bounds": all_lower_bounds.tolist(),
    }
    metadata_path = os.path.join(
        save_dir, f"joint_{forecast_type}_top{top_n}_metadata.pkl"
    )
    if os.path.exists(metadata_path):
        os.remove(metadata_path)
    with open(metadata_path, "wb") as f:
        pickle.dump(metadata, f)
    print(f"Saved metadata to: {metadata_path}")

    return idata, frontier_dfs, model


def plot_logistic_fit_from_joint(
    idata,
    frontier_df,
    task_name,
    task_idx,
    forecast_days=1523,
    top_n=3,
    forecast_type="joint_hyperparameters",
    save_dir="Images",
    lower_bound=0.0,
    random_seed=42,
):
    """
    Plot the forecast for a single task from a joint model fit.

    Draws the posterior median curve, 50%/95% credible bands of the shifted
    logistic, a 95% prediction band including skew-normal observation noise,
    and the observed frontier points. Saves the figure to
    `{save_dir}/{safe_task_name}_{forecast_type}_top{top_n}_forecast.png` and
    prints a short summary.

    Args:
        idata: InferenceData from joint model
        frontier_df: DataFrame with frontier data for this task (uses the
            'days', 'date', 'score', 'model_id' and, when top_n > 1, 'rank' columns)
        task_name: name of the benchmark
        task_idx: index of this task in the joint model
        forecast_days: number of days to forecast past the last observation
        top_n: number of top models tracked
        forecast_type: type of forecast model (used in the file name)
        save_dir: directory to save plots; created if it does not exist
        lower_bound: lower asymptote (random chance baseline)
        random_seed: seed for the observation-noise draws behind the 95%
            prediction band, so the plotted band is reproducible; pass None
            for unseeded draws
    """
    from scipy.stats import skewnorm
    import matplotlib.dates as mdates

    # Extract posterior samples for THIS task (chains and draws flattened)
    posterior = idata.posterior
    L_samples = posterior["L"].sel(L_dim_0=task_idx).values.flatten()
    k_samples = posterior["k"].sel(k_dim_0=task_idx).values.flatten()
    t0_samples = posterior["t0"].sel(t0_dim_0=task_idx).values.flatten()
    xi_base_samples = posterior["xi_base"].sel(xi_base_dim_0=task_idx).values.flatten()
    s_samples = posterior["s"].sel(s_dim_0=task_idx).values.flatten()

    # Time grid in days since the first frontier observation, extended by the horizon
    t_obs = frontier_df["days"].values
    t_grid = np.linspace(t_obs.min(), t_obs.max() + forecast_days, 200)

    # Evaluate every posterior sample on the whole grid at once
    # (rows = posterior samples, columns = grid points).
    L_col = L_samples[:, None]
    logistic_01 = 1.0 / (1 + np.exp(-k_samples[:, None] * (t_grid - t0_samples[:, None])))
    predictions = lower_bound + (L_col - lower_bound) * logistic_01

    # Heteroskedastic noise for the shifted range (mirrors the model). When the
    # logistic saturates, rounding can make the product slightly negative; clip
    # at 0 so sqrt does not yield NaN (a NaN scale makes skewnorm.rvs raise).
    spread = np.clip((predictions - lower_bound) * (L_col - predictions), 0.0, None)
    variance_shape = np.sqrt(spread)
    max_variance = (L_col - lower_bound) / 2.0
    noise_factor = variance_shape / np.maximum(max_variance, 1e-10)  # Avoid division by zero
    xi_grid = 0.01 + xi_base_samples[:, None] * noise_factor

    median_pred = np.percentile(predictions, 50, axis=0)
    lower_50_logistic = np.percentile(predictions, 25, axis=0)
    upper_50_logistic = np.percentile(predictions, 75, axis=0)
    lower_95_logistic = np.percentile(predictions, 2.5, axis=0)
    upper_95_logistic = np.percentile(predictions, 97.5, axis=0)

    # Prediction intervals with observation noise: one skew-normal draw per
    # posterior sample and grid point, clipped to the valid score range.
    rng = np.random.default_rng(random_seed)
    noisy_scores = skewnorm.rvs(
        s_samples[:, None], loc=predictions, scale=xi_grid, random_state=rng
    )
    noisy_scores = np.clip(noisy_scores, 0, 1)
    lower_95_sampling = np.percentile(noisy_scores, 2.5, axis=0)
    upper_95_sampling = np.percentile(noisy_scores, 97.5, axis=0)

    min_date = frontier_df["date"].min()
    dates_grid = [min_date + timedelta(days=int(d)) for d in t_grid]

    # Plot
    fig, ax = plt.subplots(figsize=(12, 6))

    # Add horizontal line for lower bound if > 0
    if lower_bound > 0.01:
        ax.axhline(
            lower_bound,
            color="gray",
            linestyle=":",
            alpha=0.5,
            linewidth=1.5,
            label=f"Random chance ({lower_bound:.1%})",
        )

    ax.fill_between(
        dates_grid,
        lower_95_sampling,
        upper_95_sampling,
        alpha=0.1,
        color="#F18F01",
        label="95% prediction interval",
    )
    ax.fill_between(
        dates_grid,
        lower_95_logistic,
        upper_95_logistic,
        alpha=0.2,
        color="#2E86AB",
        label="95% CI (logistic)",
    )
    ax.fill_between(
        dates_grid,
        lower_50_logistic,
        upper_50_logistic,
        alpha=0.3,
        color="#2E86AB",
        label="50% CI (logistic)",
    )
    ax.plot(
        dates_grid,
        median_pred,
        "-",
        linewidth=2,
        color="#2E86AB",
        label="Median prediction",
    )

    if top_n == 1:
        ax.plot(
            frontier_df["date"],
            frontier_df["score"],
            "o",
            markersize=8,
            color="#A23B72",
            label="Observed frontier",
            zorder=10,
        )
    else:
        colors = plt.cm.RdYlBu_r(np.linspace(0.2, 0.8, top_n))
        for rank in range(1, top_n + 1):
            rank_data = frontier_df[frontier_df["rank"] == rank]
            if not rank_data.empty:
                ax.plot(
                    rank_data["date"],
                    rank_data["score"],
                    "o",
                    markersize=10 - (rank - 1) * 1.0,
                    color=colors[rank - 1],
                    label=f"Top-{rank} at release",
                    alpha=0.8,
                    zorder=10 - rank,
                )

    # Fix the y-range before placing the "Last obs" label, so its position is
    # computed from the final limits rather than the autoscaled ones.
    ax.set_ylim(0, 1.05)

    last_date = frontier_df["date"].max()
    ax.axvline(last_date, color="gray", linestyle="--", alpha=0.5, linewidth=1)
    ax.text(
        last_date,
        ax.get_ylim()[0] + 0.05,
        "Last obs",
        rotation=90,
        verticalalignment="bottom",
        fontsize=9,
        color="gray",
    )

    ax.set_xlabel("Date", fontsize=12, fontweight="bold")
    ax.set_ylabel("Score", fontsize=12, fontweight="bold")

    title_suffix = " (Frontier, Joint)" if top_n == 1 else f" (Top-{top_n}, Joint)"
    ax.set_title(
        f"Logistic Growth Forecast: {task_name}{title_suffix}",
        fontsize=14,
        fontweight="bold",
    )
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=9)

    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m"))
    ax.xaxis.set_major_locator(mdates.YearLocator())
    plt.xticks(rotation=45, ha="right")

    plt.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    safe_task_name = task_name.replace("/", "_").replace(" ", "_").replace(".", "_")
    plot_path = os.path.join(
        save_dir, f"{safe_task_name}_{forecast_type}_top{top_n}_forecast.png"
    )
    # Remove any existing file BEFORE saving (ensures the new name's case is
    # used on case-insensitive macOS file systems). Removing it after saving
    # would delete the plot that was just written.
    if os.path.exists(plot_path):
        os.remove(plot_path)
    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"Saved plot to: {plot_path}")
    plt.show()
    plt.close(fig)  # free memory when called in a loop over many benchmarks

    # Print summary
    print(f"\n{'=' * 80}")
    title_text = (
        f"FORECAST SUMMARY: {task_name} (Joint Model, Top-{top_n})"
        if top_n > 1
        else f"FORECAST SUMMARY: {task_name} (Joint Model, Frontier)"
    )
    print(title_text)
    print(f"{'=' * 80}")
    print(f"Data points: {len(frontier_df)}")
    print(f"Unique models: {frontier_df['model_id'].nunique()}")
    print(f"Lower asymptote (random chance): {lower_bound:.3f}")
    print(f"Current best score: {frontier_df['score'].max():.3f}")
    print(
        f"Predicted upper asymptote (L): {np.median(L_samples):.3f} [{np.percentile(L_samples, 2.5):.3f}, {np.percentile(L_samples, 97.5):.3f}]"
    )
    print(
        f"Growth rate (k): {np.median(k_samples):.4f} [{np.percentile(k_samples, 2.5):.4f}, {np.percentile(k_samples, 97.5):.4f}]"
    )
    print(f"Inflection point (t0): {np.median(t0_samples):.0f} days")


def plot_joint_hyperparameters(idata, task_names, save_dir="Images"):
    """
    Plot and summarize the learned hyperparameter distributions of the joint model.

    Draws a 2x2 grid of 2-D histograms, one per (mu, sigma) hyperparameter pair
    (asymptote L, growth rate k, base noise xi_base, skewness s). Saves it as
    `0_joint_hyperparameters.png` in `save_dir` and prints the posterior median
    with a 2.5%-97.5% interval for each hyperparameter.

    Args:
        idata: InferenceData from joint model; its posterior must contain
            L_mu, L_sigma, k_mu, k_sigma, xi_base_mu, xi_base_sigma, s_mu, s_sigma
        task_names: list of task names (only its length is reported)
        save_dir: directory to save plots; created if it does not exist
    """
    posterior = idata.posterior

    # One entry per panel, in axes order [0,0], [0,1], [1,0], [1,1]:
    # (hyperparameter prefix, description for axis labels, panel title,
    #  colormap, summary header, decimals for the printed summary)
    panel_specs = [
        ("L", "asymptote", "Asymptote Hyperparameters", "Blues", "Asymptote (L)", 3),
        ("k", "growth rate", "Growth Rate Hyperparameters", "Greens", "Growth rate (k)", 4),
        ("xi_base", "noise", "Base Noise Hyperparameters", "Oranges", "Base noise (xi_base)", 3),
        ("s", "skewness", "Skewness Hyperparameters", "Reds", "Skewness (s)", 2),
    ]
    stats = ("mu", "sigma")

    # Flatten chains x draws into 1-D sample vectors once, reused for plot and summary
    samples = {
        f"{prefix}_{stat}": posterior[f"{prefix}_{stat}"].values.flatten()
        for prefix, *_ in panel_specs
        for stat in stats
    }

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for ax, (prefix, description, title, cmap, _, _) in zip(axes.flat, panel_specs):
        ax.hist2d(samples[f"{prefix}_mu"], samples[f"{prefix}_sigma"], bins=50, cmap=cmap)
        ax.set_xlabel(f"{prefix}_mu ({description} mean)", fontweight="bold")
        ax.set_ylabel(f"{prefix}_sigma ({description} std)", fontweight="bold")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()

    os.makedirs(save_dir, exist_ok=True)
    plot_path = os.path.join(save_dir, "0_joint_hyperparameters.png")
    if os.path.exists(plot_path):  # Remove existing file to ensure correct case on macOS
        os.remove(plot_path)
    fig.savefig(plot_path, dpi=300, bbox_inches="tight")
    print(f"Saved hyperparameter plot to: {plot_path}")
    plt.show()
    plt.close(fig)

    # Print summary statistics: median [2.5%, 97.5%] per hyperparameter
    print(f"\n{'=' * 80}")
    print("JOINT HYPERPARAMETER SUMMARY")
    print(f"{'=' * 80}")
    print(f"Number of benchmarks: {len(task_names)}")
    for prefix, _, _, _, header, decimals in panel_specs:
        print(f"\n{header}:")
        for stat in stats:
            name = f"{prefix}_{stat}"
            values = samples[name]
            median = np.median(values)
            low, high = np.percentile(values, [2.5, 97.5])
            print(
                f"  {name}: {median:.{decimals}f} [{low:.{decimals}f}, {high:.{decimals}f}]"
            )

### Internal benchmarks

In [None]:
# ============================================================================
# EXECUTION: Fit joint model to all internal benchmarks
# ============================================================================

# Only internal benchmarks with at least 3 observations take part in the fit
benchmark_dict = {
    task_name: data["df_long"]
    for task_name, data in internal_matrices.items()
    if len(data["df_long"]) >= 3
}

print(f"Fitting joint model with {len(benchmark_dict)} internal benchmarks")

# Joint hierarchical fit; artifacts are written to the function's default save_dir
idata_joint, frontier_dfs, model_joint = fit_logistic_joint_hyperparameters(
    benchmark_dict,
    lower_bounds_dict=lower_bounds_dict,
    forecast_type="joint_hyperparameters_internalOnly",
    top_n=3,
    n_tune=1000,
    n_samples=2000,
)

# Posterior of the shared hyperparameters, followed by their convergence diagnostics
plot_joint_hyperparameters(idata_joint, list(benchmark_dict))

shared_hyperparameters = [
    f"{prefix}_{stat}"
    for prefix in ("L", "k", "xi_base", "s")
    for stat in ("mu", "sigma")
]
print("\nJoint Model MCMC Diagnostics:")
print(az.summary(idata_joint, var_names=shared_hyperparameters))

In [None]:
# ============================================================================
# PLOT FORECASTS FOR INDIVIDUAL BENCHMARKS
# ============================================================================

# Use the same forecast_type as the internal-only joint fit above, so the
# forecast plots share a file-name prefix with that fit's saved idata/CSVs
# (the all-benchmark cells below follow the same convention).
internal_forecast_type = "joint_hyperparameters_internalOnly"

# frontier_dfs is ordered like the joint model's tasks, so enumerate() yields
# each benchmark's index into the per-task posterior parameters.
for task_idx, (task_name, frontier_df) in enumerate(frontier_dfs.items()):
    print(f"\nPlotting: {task_name}")

    # Get lower bound for this benchmark (0.0 when missing or NaN)
    lower_bound = 0.0
    if lower_bounds_dict and task_name in lower_bounds_dict:
        lb = lower_bounds_dict[task_name]
        if not pd.isna(lb):
            lower_bound = lb

    plot_logistic_fit_from_joint(
        idata_joint,
        frontier_df,
        task_name,
        task_idx,
        forecast_days=1523,
        top_n=3,
        forecast_type=internal_forecast_type,
        lower_bound=lower_bound,
    )

print(f"\n{'=' * 80}")
print(f"Successfully fit joint model with {len(benchmark_dict)} benchmarks")
print(f"{'=' * 80}")

### All benchmarks

In [None]:
# ============================================================================
# EXECUTION: Fit joint model to ALL benchmarks (internal + external)
# ============================================================================

# Combine all benchmarks with at least 3 observations. Names are kept as-is
# (lower_bounds_dict is looked up by them), so an external benchmark that
# shares a name with an internal one replaces it — warn instead of doing so silently.
benchmark_dict_all = {}
for source_label, matrices in (
    ("internal", internal_matrices),
    ("external", external_matrices),
):
    for task_name, data in matrices.items():
        df_long = data["df_long"]
        if len(df_long) < 3:  # Minimum observations
            continue
        if task_name in benchmark_dict_all:
            print(
                f"Warning: {source_label} benchmark '{task_name}' has the same name "
                "as an already-added benchmark and replaces it in the joint fit."
            )
        benchmark_dict_all[task_name] = df_long

print(f"Fitting joint model with {len(benchmark_dict_all)} benchmarks (internal + external)")

# Fit joint model
idata_joint_all, frontier_dfs_all, model_joint_all = fit_logistic_joint_hyperparameters(
    benchmark_dict_all,
    n_samples=2000,
    n_tune=1000,
    top_n=3,
    forecast_type="joint_hyperparameters_all",
    lower_bounds_dict=lower_bounds_dict,
)

# Plot learned hyperparameters
plot_joint_hyperparameters(idata_joint_all, list(benchmark_dict_all.keys()))

# Print MCMC diagnostics
print("\nJoint Model (All Benchmarks) MCMC Diagnostics:")
print(
    az.summary(
        idata_joint_all,
        var_names=[
            "L_mu",
            "L_sigma",
            "k_mu",
            "k_sigma",
            "xi_base_mu",
            "xi_base_sigma",
            "s_mu",
            "s_sigma",
        ],
    )
)

In [None]:
# ============================================================================
# PLOT FORECASTS FOR INDIVIDUAL BENCHMARKS (ALL)
# ============================================================================

# frontier_dfs_all is ordered like the joint model's tasks, so enumerate()
# yields each benchmark's index into the per-task posterior parameters.
for task_idx, (task_name, frontier_df) in enumerate(frontier_dfs_all.items()):
    print(f"\nPlotting: {task_name}")

    # Get lower bound for this benchmark (0.0 when missing or NaN)
    lower_bound = 0.0
    if lower_bounds_dict and task_name in lower_bounds_dict:
        lb = lower_bounds_dict[task_name]
        if not pd.isna(lb):
            lower_bound = lb

    plot_logistic_fit_from_joint(
        idata_joint_all,
        frontier_df,
        task_name,
        task_idx,
        forecast_days=1523,
        top_n=3,
        forecast_type="joint_hyperparameters_all",
        lower_bound=lower_bound,
    )

# The combined dict is keyed by the raw names from internal_matrices /
# external_matrices (no "internal_"/"external_" prefix is added when combining),
# so classify by source dict rather than by name prefix. A name present in
# both sources counts toward both.
n_internal = sum(1 for name in frontier_dfs_all if name in internal_matrices)
n_external = sum(1 for name in frontier_dfs_all if name in external_matrices)

print(f"\n{'=' * 80}")
print(f"Successfully fit joint model with {len(benchmark_dict_all)} benchmarks")
print(f"  - Internal: {n_internal}")
print(f"  - External: {n_external}")
print(f"{'=' * 80}")

# Other

## Gamma distribution tests

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gamma  # Import gamma function from scipy


def gamma_params_from_mu_sigma(mu, sigma):
    """Map a Gamma's (mean, std) parameterization, as used by PyMC, to (shape, rate).

    Returns:
        (alpha, beta) with alpha = (mu / sigma)**2 and beta = mu / sigma**2.
    """
    variance = sigma**2
    shape = (mu / sigma) ** 2
    rate = mu / variance
    return shape, rate


def gamma_pdf(x, mu, sigma):
    """Gamma PDF using PyMC's (mu, sigma) parameterization.

    Converts to shape/rate with gamma_params_from_mu_sigma and evaluates the
    density with scipy.stats.gamma (shape alpha, scale 1/beta). Computing
    beta**alpha / Gamma(alpha) directly overflows to inf/inf = NaN for narrow
    distributions (large alpha). scipy evaluates the density in a numerically
    stable way.

    Args:
        x: point(s) at which to evaluate the density (scalar or array)
        mu: mean of the Gamma distribution
        sigma: standard deviation of the Gamma distribution

    Returns:
        Density value(s) with the same shape as x.
    """
    # Local alias avoids clashing with scipy.special.gamma imported in this cell
    from scipy.stats import gamma as gamma_dist

    alpha, beta = gamma_params_from_mu_sigma(mu, sigma)
    return gamma_dist.pdf(x, a=alpha, scale=1.0 / beta)


# Parameters
mu = 0.1
sigmas = [0.06, 0.07, 0.08]  # different standard deviations
x = np.linspace(0, 0.3, 1000)
colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
# zip() below would silently drop any sigma beyond the available colors
assert len(sigmas) <= len(colors), "Add more colors for the extra sigmas"

# Create plot
fig, ax = plt.subplots(figsize=(12, 6))

# Plot each distribution and print its equivalent (alpha, beta) parameters.
# The per-sigma print here replaces a trailing print that reused the
# loop variable after the loop and only reported the last sigma.
for sigma, color in zip(sigmas, colors):
    ax.plot(x, gamma_pdf(x, mu, sigma), "-", color=color, lw=2, label=f"σ={sigma:.3f}")
    alpha, beta = gamma_params_from_mu_sigma(mu, sigma)
    print(f"σ={sigma:.3f}: α={alpha:.2f}, β={beta:.2f}")

# Add vertical line for mean
ax.axvline(mu, color="black", linestyle="--", alpha=0.5, label="Mean (μ)")

ax.set_xlabel("x")
ax.set_ylabel("Density")
ax.set_title(f"Gamma Distribution (PyMC parameterization) - Fixed μ={mu}")
ax.grid(True, alpha=0.3)
ax.legend()
plt.show()