In [1]:
from pprint import pprint

import altair as alt
import numpy as np

from code_ramblings.distributions import PoissonNegBinomMixture
from code_ramblings.distributions import compute_value_frequencies

# --- Configuration: everything a reader might want to tune -----------------
SEED = 0
N = 10000  # number of draws from EACH component distribution

POISSON_LAM = 4  # Poisson rate (= mean of the Poisson component)
NB_N = 1  # negative binomial: number of successes (numpy's `n`)
NB_P = 0.5  # negative binomial: success probability (numpy's `p`)

rng = np.random.default_rng(SEED)

# --- Simulate ---------------------------------------------------------------
# Two equally sized samples from different count distributions, plus their
# concatenation (2 * N values). The draw order (Poisson first) determines the
# RNG stream, so keep it if you want reproducible numbers.
poisson = rng.poisson(POISSON_LAM, N)
nb = rng.negative_binomial(NB_N, NB_P, N)
nb_poisson = np.concatenate([poisson, nb])

distributions = {
    "poisson": poisson,
    "negative binomial": nb,
    "mixed": nb_poisson,
}

# NOTE(review): compute_value_frequencies is assumed to return a long-format
# frame with `value`, `count` and `distribution` columns (the latter holding
# the dict keys above) — the chart encodings and filters below rely on that.
df = compute_value_frequencies(distributions)

# --- Plot -------------------------------------------------------------------
base = (
    alt.Chart(df)
    .mark_bar()
    .encode(
        x=alt.X("value:O", title="Value"),
        y=alt.Y("count:Q", title="Frequency"),
        color=alt.Color("distribution:N", title="Distribution"),
    )
)

# Top row: the pooled sample next to its two components. The component bars
# share an x position and therefore stack, so because `mixed` is the plain
# concatenation of both samples the two top panels should have equal heights.
# Bottom row: the same components drawn side by side via an x offset.
mixed = base.transform_filter(alt.datum.distribution == "mixed").properties(
    title="Mixed sample"
)
components = base.transform_filter(alt.datum.distribution != "mixed")
split = components.properties(title="Components (stacked)")
offset = components.encode(xOffset=alt.XOffset("distribution:N")).properties(
    title="Components (side by side)"
)


(mixed | split) & offset

In [2]:
# Fit a Poisson / negative-binomial mixture to the pooled sample from the
# previous cell. Because `nb_poisson` was built from equal-sized Poisson(4) and
# NegBinom(n=1, p=0.5) samples, the fit should recover parameters close to
# those values with a mixing weight near 0.5.
mixture_model = PoissonNegBinomMixture()
mixture_results = mixture_model.fit(nb_poisson)


pprint(mixture_results)
# A fit that did not converge would make the comparison below misleading, so
# fail loudly (after printing the results for inspection) instead of plotting.
assert mixture_results.converged, "mixture fit did not converge"

# Short preview of draws from the fitted model, seeded for reproducibility.
display(mixture_model.sample(100, random_state=SEED))


# Draw exactly as many model samples as there are observations so the two
# frequency histograms are on the same scale and directly comparable.
distributions = {
    "mixed": nb_poisson,
    "model": mixture_model.sample(len(nb_poisson), random_state=SEED),
}

df = compute_value_frequencies(distributions)


# Observed vs. model-generated frequencies, grouped side by side per value.
comparison = (
    alt.Chart(df)
    .mark_bar()
    .encode(
        x=alt.X("value:O", title="Value"),
        y=alt.Y("count:Q", title="Frequency"),
        color=alt.Color("distribution:N", title="Distribution"),
        xOffset=alt.XOffset("distribution:N"),
    )
)


comparison

MixtureResults(params=MixtureParams(pi=np.float64(0.5040815986738955),
                                    poisson_lambda=np.float64(4.020087314182715),
                                    negbinom_r=np.float64(0.9953834799503147),
                                    negbinom_p=np.float64(0.5061942728959601)),
               log_likelihood=np.float64(-40806.92373863531),
               converged=True,
               n_iterations=373)


array([ 0,  2,  2,  1,  3,  0,  0,  0,  0,  3,  0,  3,  1,  6,  0,  8,  1,
        1,  2,  6,  4,  3,  4,  0,  1,  6,  3,  0,  0,  0,  1,  2,  7,  5,
        0,  4,  5,  0,  3,  5,  0,  3,  0,  4,  1,  2,  5,  3,  6,  0,  1,
        5,  1,  1,  5,  8,  3,  1,  4, 11,  3,  5,  6,  2,  5,  0,  3,  3,
        3,  2,  0,  0,  5,  2,  6,  3,  0,  0,  1,  7,  0,  5,  3,  0,  5,
        1,  0,  0,  3,  0,  0,  0,  7,  0,  0,  1,  3,  0,  1,  0])