<a target="_blank" href="https://colab.research.google.com/github/acciochris/stats-project-2025/blob/main/chi_squared_conf_int_demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [None]:
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from statistics import NormalDist
from dataclasses import dataclass


def get_intervals(probs: list[float]) -> list[tuple[float, float]]:
    """Partition the unit interval into consecutive slices, one per probability.

    The i-th slice starts where the previous one ended and has width
    ``probs[i]``, so a uniform draw on ``[0, 1)`` lands in slice i with
    probability ``probs[i]``.

    Args:
        probs: Category probabilities; must be non-negative and sum to 1.

    Returns:
        A list of ``(start, end)`` tuples covering ``[0, 1]`` without gaps.

    Raises:
        ValueError: If a probability is negative or the probabilities do not
            sum to 1 (within ``math.isclose`` tolerance).
    """
    intervals: list[tuple[float, float]] = []
    start = 0.0
    for prob in probs:
        if prob < 0:
            raise ValueError("probabilities must be non-negative")
        intervals.append((start, start + prob))
        start += prob
    if not math.isclose(start, 1.0):
        raise ValueError("sum of probabilities must be equal to 1")

    # Accumulated rounding can leave the running total a hair below 1.0
    # (e.g. ten 0.1 steps end at 0.9999999999999999).  Pin the final edge so a
    # draw just under 1.0 can never fall outside every slice.  An empty input
    # cannot reach this point because its total of 0.0 fails the check above.
    last_start, _ = intervals[-1]
    intervals[-1] = (last_start, max(last_start, 1.0))
    return intervals


def random_counts(rng: np.random.Generator, probs: list[int], n: int):
    """Draw one sample of size ``n`` and tally how many fall in each category.

    Categories are labelled with consecutive uppercase ASCII letters
    ("A", "B", ...), one per entry of ``probs``.

    Args:
        rng: Generator supplying the uniform draws.
        probs: Category probabilities (passed to ``get_intervals``, which
            requires them to sum to 1).
        n: Number of observations to draw.

    Returns:
        A Series of counts per category label, in category order (unsorted).

    Raises:
        ValueError: If there are more categories than uppercase letters.
    """
    from string import ascii_uppercase

    n_categories = len(probs)
    if n_categories > len(ascii_uppercase):
        raise ValueError("too many categories")
    labels = list(ascii_uppercase[:n_categories])

    # Draw first, then build the bins, mirroring the original evaluation order.
    draws = rng.random((n,))
    bins = pd.IntervalIndex.from_tuples(get_intervals(probs), closed="left")

    # Each draw is assigned to the slice of [0, 1) it falls in, then the
    # interval categories are relabelled with letters.
    assigned = pd.cut(draws, bins).rename_categories(labels)
    return pd.Series(assigned).value_counts(sort=False)


def sampling_dist(sample_fn, trials=100):
    """Simulate the sampling distribution of category proportions.

    ``sample_fn`` is called once up front to learn the category labels and the
    sample size (the sum of its counts); that first draw is not recorded.  It
    is then called ``trials`` more times and every count is collected.

    Args:
        sample_fn: Zero-argument callable returning a Series of counts indexed
            by category.
        trials: Number of recorded samples.

    Returns:
        A DataFrame with one row per trial and one column per category,
        holding counts divided by the sample size.
    """
    probe: pd.Series = sample_fn()
    sample_size = probe.sum()
    per_category = {label: [] for label in probe.index}

    for _ in range(trials):
        for label, count in sample_fn().items():
            per_category[label].append(count)

    counts = pd.DataFrame(per_category)
    return counts / sample_size


def plot_df_hist(df: pd.DataFrame, *, bar_width=0.02):
    """Draw one histogram per column, side by side with a shared y axis.

    Each panel's width is proportional to its column's value range so that a
    bar of ``bar_width`` looks equally wide in every panel.

    Args:
        df: Data to plot; one histogram is drawn per column.
        bar_width: Width of each histogram bin, in data units.

    Returns:
        The matplotlib Figure containing the panels.
    """
    maxs = df.max()
    mins = df.min()
    # A constant column has zero spread; give it one bar's worth of width so
    # the layout never receives a zero ratio.
    spans = [max(float(maxs[col] - mins[col]), bar_width) for col in df.columns]

    fig, axs = plt.subplots(
        1,
        len(df.columns),
        figsize=(10, 5),
        sharey=True,
        tight_layout=True,
        width_ratios=spans,
        # Always get a 2-D array of axes, even for a single column.
        squeeze=False,
    )

    for ax, col in zip(axs[0], df.columns):
        lo, hi = mins[col], maxs[col]
        # np.arange(lo, hi, bar_width) stops short of hi, which left the
        # largest observations outside every bin.  Build enough edges to
        # reach hi and pin the last edge so the maximum is always counted.
        n_bins = max(1, math.ceil((hi - lo) / bar_width))
        edges = lo + bar_width * np.arange(n_bins + 1)
        edges[-1] = max(edges[-1], hi)
        ax.hist(df[col], bins=edges)
        ax.set_title(col)
    fig.supxlabel("Sample Proportion")
    fig.supylabel("Counts")
    return fig


def plot_df_hist_overlay(df: pd.DataFrame, *, bar_width=0.02):
    """Overlay semi-transparent histograms of every column on the current axes.

    All columns share the same bin edges, spanning the smallest to the largest
    value found anywhere in ``df``.

    Args:
        df: Data to plot; one histogram is drawn per column.
        bar_width: Width of each histogram bin, in data units.
    """
    df_max = df.max().max()
    df_min = df.min().min()
    # np.arange(df_min, df_max, bar_width) stops short of df_max, which left
    # the largest observations outside every bin.  Build enough edges to reach
    # df_max and pin the last edge so the maximum is always counted.
    n_bins = max(1, math.ceil((df_max - df_min) / bar_width))
    bins = df_min + bar_width * np.arange(n_bins + 1)
    bins[-1] = max(bins[-1], df_max)

    for col in df.columns:
        plt.hist(df[col], bins=bins, alpha=0.5, label=col)
    plt.legend()
    plt.xlabel("Sample Proportion")
    plt.ylabel("Counts")


def one_prop_z_int(props, n: int, conf: float = 0.95):
    """Compute a one-proportion z interval for each sample proportion.

    For each proportion ``p`` the interval is ``p +/- |z| * sqrt(p(1-p)/n)``,
    where ``z`` is the standard-normal quantile at ``(1 - conf) / 2``.

    Args:
        props: Iterable of sample proportions.
        n: Sample size behind each proportion.
        conf: Confidence level of each interval.

    Returns:
        A list of ``(lower, upper)`` tuples, one per proportion, in order.
    """
    # This is the lower-tail quantile (negative for conf > 0); only its
    # magnitude matters, hence the abs() below.
    z_critical = NormalDist().inv_cdf((1 - conf) / 2)

    def bounds(p):
        margin = abs(z_critical * math.sqrt(p * (1 - p) / n))
        return (p - margin, p + margin)

    return [bounds(p) for p in props]


def within_interval(
    props: pd.DataFrame, population_prop: list[float], n: int, conf: float = 0.95
) -> pd.DataFrame:
    """Flag, per trial and category, whether the interval captures the truth.

    For every row of ``props`` a one-proportion z interval is built around each
    sample proportion (via ``one_prop_z_int``) and compared with the matching
    population proportion.

    Args:
        props: Sample proportions; one row per trial, one column per category.
        population_prop: True proportion of each category, in column order.
        n: Sample size behind each row.
        conf: Confidence level of each individual interval.

    Returns:
        A boolean DataFrame with the same columns and index as ``props``;
        ``True`` where the population proportion lies strictly inside the
        interval.
    """
    result = np.empty(props.shape, bool)
    # Walk the rows by position.  The index labels of ``props`` are not
    # guaranteed to be 0..len-1, so they must not be used to address ``result``.
    for i, row in enumerate(props.itertuples(index=False, name=None)):
        intervals = one_prop_z_int(list(row), n=n, conf=conf)
        for j, (lo, hi) in enumerate(intervals):
            result[i, j] = lo < population_prop[j] < hi

    return pd.DataFrame(result, columns=props.columns, index=props.index)


@dataclass
class CategoricalPopulation:
    rng: np.random.Generator
    probs: list[int]
    n: int

    def one_sample(self):
        return random_counts(self.rng, self.probs, self.n)

    def sampling_dist(self, trials: int):
        return sampling_dist(self.one_sample, trials=trials)

    def within_interval_counts(
        self,
        conf: float,
        *,
        trials: int | None = None,
        dist: pd.DataFrame | None = None,
    ) -> pd.Series:
        if dist is None:
            if trials is None:
                raise ValueError("must provide either trials or dist")
            else:
                dist = self.sampling_dist(trials)

        return within_interval(
            props=dist, population_prop=self.probs, n=self.n, conf=conf
        ).sum(axis=1)

    def get_overall_conf(
        self,
        conf: float,
        k: int | None = None,
        *,
        trials: int = 2000,
        dist: pd.DataFrame | None = None,
    ):
        if k is None:  # all within interval
            k = len(self.probs)
        if k > len(self.probs) or k < 0:
            raise ValueError("k must be within 0 and len(probs)")

        results = self.within_interval_counts(
            conf=conf, trials=trials, dist=dist
        ).value_counts()
        total = 0
        for i in range(k, len(self.probs) + 1):
            total += results.get(i)

        return float(total / trials)

    def get_individual_conf(
        self,
        overall_conf: float,
        k: int | None = None,
        *,
        trials: int = 2000,
        accuracy: float = 0.0005,
    ):
        lo = overall_conf
        hi = 1.0
        dist = self.sampling_dist(trials=trials)

        while (hi - lo) > accuracy:
            mid = (lo + hi) / 2
            c = self.get_overall_conf(conf=mid, k=k, dist=dist)
            if c < overall_conf:
                lo = mid
            else:
                hi = mid

        return mid


In [None]:
from ipywidgets import interact_manual, FloatText, IntText, Text


@interact_manual(
    p_i=Text("[0.15, 0.25, 0.3, 0.3]"),
    n=IntText(200),
    k=IntText(4),
    trials=IntText(2000),
    conf=FloatText(0.95),
)
def get_overall_conf(p_i, n, k, trials, conf):
    """Widget callback: estimate the overall confidence for the given inputs.

    Args:
        p_i: Text of a Python list of category probabilities.
        n: Sample size.
        k: Minimum number of intervals that must capture the true proportion.
        trials: Number of simulated samples.
        conf: Confidence level of each individual interval.

    Returns:
        The simulated chance that at least ``k`` intervals succeed.
    """
    generator = np.random.default_rng()
    # NOTE(review): eval() runs whatever Python is typed into the text box.
    # That is tolerable only because the notebook user is running their own
    # kernel; do not reuse this with input from anyone else.
    probabilities = eval(p_i)
    population = CategoricalPopulation(rng=generator, probs=probabilities, n=n)
    return population.get_overall_conf(conf=conf, k=k, trials=trials)


In [None]:
@interact_manual(
    p_i=Text("[0.15, 0.25, 0.3, 0.3]"),
    n=IntText(200),
    k=IntText(4),
    trials=IntText(2000),
    overall_conf=FloatText(0.9),
)
def get_individual_conf(p_i, n, k, trials, overall_conf):
    """Widget callback: find the per-interval confidence for a target overall one.

    Args:
        p_i: Text of a Python list of category probabilities.
        n: Sample size.
        k: Minimum number of intervals that must capture the true proportion.
        trials: Number of simulated samples.
        overall_conf: Target chance that at least ``k`` intervals succeed.

    Returns:
        The individual confidence level found by the population's search.
    """
    generator = np.random.default_rng()
    # NOTE(review): eval() runs whatever Python is typed into the text box.
    # That is tolerable only because the notebook user is running their own
    # kernel; do not reuse this with input from anyone else.
    probabilities = eval(p_i)
    population = CategoricalPopulation(rng=generator, probs=probabilities, n=n)
    return population.get_individual_conf(
        overall_conf=overall_conf, k=k, trials=trials
    )
