In [9]:
import numpy as np
import pandas as pd

from mh_utils import (
    PermutationPropsal, PermutationStatistic, MCMCISPermutationTest
)
from samc import SAMC
from mcmcis import SmoothedMCMCIS
from variance_estimation import MCMCVarianceEstimation

import seaborn as sns

from itertools import combinations
from math import comb

In [10]:
def permutation_test(
    s1,
    s2,
    stat=None,
    alternative="greater",   # "greater" | "less" | "two-sided"
    exact=True,
    max_enumerations=5_000_000,
    seed=None,
):
    """
    Exact (or Monte Carlo) permutation test for two samples.

    Parameters
    ----------
    s1, s2 : array-like
        Samples; flattened to 1-D float arrays. Each must be non-empty.
    stat : callable or None
        Test statistic function stat(s1, s2) -> float.
        If None: uses difference of means: mean(s1) - mean(s2).
    alternative : {"greater","less","two-sided"}
        Tail definition relative to the observed statistic.
    exact : bool
        If True, enumerate all label permutations (over group membership).
        If total permutations exceed `max_enumerations`, raises ValueError.
        If False, draws `max_enumerations` independent uniform random
        relabelings.
    max_enumerations : int
        For exact=True: hard cap on exact enumerations.
        For exact=False: number of Monte Carlo permutations to sample
        (must be at least 1).
    seed : None, int, or anything accepted by ``np.random.default_rng``
        Seed for the Monte Carlo branch. The function draws from its own
        ``Generator``, so ``np.random.seed`` has no effect on it; pass a value
        here to make ``exact=False`` results reproducible. Ignored when
        ``exact=True``.

    Returns
    -------
    p_value : float
        Estimated p-value (exact or MC).
    observed : float
        Observed statistic on the original split.
    enumerations_used : int
        Number of permutations actually evaluated.
    exact_flag : bool
        True if exact enumeration was used; False if Monte Carlo sampling.

    Raises
    ------
    ValueError
        If `alternative` is not one of the supported values, if either sample
        is empty, if exact enumeration would exceed `max_enumerations`, or if
        `max_enumerations` < 1 in Monte Carlo mode.
    """
    # Fail fast on a bad tail specification, before any statistic is evaluated.
    if alternative not in ("greater", "less", "two-sided"):
        raise ValueError("alternative must be 'greater', 'less', or 'two-sided'")

    s1 = np.asarray(s1, dtype=float).ravel()
    s2 = np.asarray(s2, dtype=float).ravel()
    n1, n2 = len(s1), len(s2)
    if n1 == 0 or n2 == 0:
        # An empty group would make the default statistic NaN and the
        # resulting p-value meaningless.
        raise ValueError("s1 and s2 must each contain at least one observation")
    N = n1 + n2
    pooled = np.concatenate([s1, s2])

    if stat is None:
        def stat(a, b):  # difference in means
            return float(np.mean(a) - np.mean(b))

    observed = stat(s1, s2)
    total_perms = comb(N, n1)

    if exact:
        if total_perms > max_enumerations:
            raise ValueError(
                f"Exact enumeration requires {total_perms} labelings "
                f"but max_enumerations={max_enumerations}."
            )
        count = 0
        # Every way of choosing which n1 pooled observations form group 1.
        for idx in combinations(range(N), n1):
            mask = np.zeros(N, dtype=bool)
            mask[list(idx)] = True
            t = stat(pooled[mask], pooled[~mask])
            if _tail_indicator(t, observed, alternative):
                count += 1
        p_value = count / total_perms
        return float(p_value), float(observed), int(total_perms), True

    # ---- Monte Carlo: max_enumerations independent uniform relabelings ----
    if max_enumerations < 1:
        raise ValueError(
            "max_enumerations must be at least 1 for Monte Carlo sampling."
        )
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(max_enumerations):
        mask = np.zeros(N, dtype=bool)
        # choose group-1 indices uniformly without replacement (within each draw)
        mask[rng.choice(N, size=n1, replace=False)] = True
        t = stat(pooled[mask], pooled[~mask])
        if _tail_indicator(t, observed, alternative):
            hits += 1

    # Plain Monte Carlo estimate; change to (hits+1)/(max_enumerations+1) if you want a small-sample guard
    p_value = hits / max_enumerations
    return float(p_value), float(observed), int(max_enumerations), False


def _tail_indicator(t, t_obs, alternative):
    """
    Return whether statistic `t` is at least as extreme as `t_obs`.

    "greater" tests t >= t_obs, "less" tests t <= t_obs, and "two-sided"
    compares absolute values. Any other `alternative` raises ValueError.
    """
    if alternative == "greater":
        return t >= t_obs
    if alternative == "less":
        return t <= t_obs
    if alternative == "two-sided":
        return abs(t) >= abs(t_obs)
    raise ValueError("alternative must be 'greater', 'less', or 'two-sided'")

In [11]:
class DifferenceInMeans(PermutationStatistic):
    """Permutation statistic: mean of the first sample minus mean of the second."""

    def __call__(self, s1, s2):
        """Return ``np.mean(s1) - np.mean(s2)``."""
        mean_first = np.mean(s1)
        mean_second = np.mean(s2)
        return mean_first - mean_second




class SwapProposal(PermutationPropsal):
    """Proposal that exchanges `swap_size` randomly chosen elements between two samples."""

    def __init__(self, swap_size: int):
        # Number of elements exchanged between the groups on each call.
        self.swap_size = swap_size

    def __call__(self, s1, s2):
        """
        Return new arrays ``(new_s1, new_s2)`` with `swap_size` elements swapped.

        Positions are drawn without replacement from NumPy's global RNG, first
        for `s1` and then for `s2`. The inputs themselves are not modified.
        """
        k = self.swap_size
        pick_from_s1 = np.random.choice(s1.shape[0], k, replace=False)
        pick_from_s2 = np.random.choice(s2.shape[0], k, replace=False)

        proposed_s1 = s1.copy()
        proposed_s2 = s2.copy()
        # Both right-hand sides read the untouched inputs, so the exchange is symmetric.
        proposed_s1[pick_from_s1], proposed_s2[pick_from_s2] = (
            s2[pick_from_s2],
            s1[pick_from_s1],
        )

        return proposed_s1, proposed_s2


In [12]:
# Seed NumPy's legacy global RNG so the simulated samples below are
# reproducible. SwapProposal also draws from this global RNG (np.random.choice).
np.random.seed(111)

# Two simulated count samples of unequal size: 15 draws from Poisson(5)
# and 10 draws from Poisson(2).
s1 = np.random.poisson(5, 15)
s2 = np.random.poisson(2, 10)


# Objects shared by the MCMC-IS runs below: the test statistic
# (mean(s1) - mean(s2)) and a proposal that swaps 2 elements between groups.
diff_in_means = DifferenceInMeans()
proposal_swaps = SwapProposal(swap_size=2)

# Variance estimator passed to MCMCISPermutationTest.
# NOTE(review): the meaning of method="sv" is defined in variance_estimation — confirm.
sv_var = MCMCVarianceEstimation(method="sv")

In [13]:
# Smoothed MCMC-IS scheme, configured with the observed difference in means
# as lambda_star.
smcmcis = SmoothedMCMCIS(
    s1=s1,
    s2=s2,
    lambda_star=diff_in_means(s1, s2),
    target_prob=0.01,
    statistic_func=diff_in_means
)

smoothed_mcmcis = MCMCISPermutationTest(
    s1=s1,
    s2=s2,
    J=4,
    # Same value as before (100000). It was written `1_000_00`, whose digit
    # grouping reads like one million.
    # NOTE(review): the SAMC run in this notebook uses T=1_000_000 — confirm
    # that 100_000 is the intended value here.
    T=100_000,
    B=25_000,
    ais=smcmcis,
    statistic_func=diff_in_means,
    proposal_func=proposal_swaps,
    variance_estimation_method=sv_var,
    log_scale=False,
    calc_estimator_variance=True,
)

smoothed_mcmcis.run_chain()

### Starting Adaptation Chain 1/4 ###
p-value = 0.00041668405985174336
SD estimate = 0.0001719147708442199
Acceptance rate = 0.7735
### Starting Adaptation Chain 2/4 ###
p-value = 0.00033840220489266475
SD estimate = 0.00019768089398228112
Acceptance rate = 0.6682
### Starting Adaptation Chain 3/4 ###
p-value = 0.00047464134102574333
SD estimate = 0.000257366697201305
Acceptance rate = 0.6329
### Starting Adaptation Chain 4/4 ###
p-value = 0.00045922812677790366
SD estimate = 0.0002561036428109732
Acceptance rate = 0.6265


<mh_utils.MCMCISPermutationTest at 0x111fbe850>

In [14]:
# SAMC scheme, configured with the observed difference in means as lambda_star.
samc = SAMC(
    lambda_star=diff_in_means(s1, s2),
    num_regions=100,
    lower_bound = -1,
    t0=10_000
)

# Same statistic, proposal and variance estimator as the smoothed run, but with
# J=2, T=1_000_000, log_scale=True and calc_estimator_variance=False.
samc_mcmcis = MCMCISPermutationTest(
    s1=s1,
    s2=s2,
    J=2,
    T=1_000_000,
    B=25_000,
    ais=samc,
    statistic_func=diff_in_means,
    proposal_func=proposal_swaps,
    variance_estimation_method=sv_var,
    log_scale=True,
    calc_estimator_variance=False,
)

# NOTE(review): the recorded output of this cell contains a stray source line
# (`return np.exp(-self.theta[self._find_region(x)])`), which looks like a
# runtime warning raised inside SAMC — worth checking.
samc_mcmcis.run_chain()

### Starting Adaptation Chain 1/2 ###


  return np.exp(-self.theta[self._find_region(x)])


p-value = 0.0003638080725444548
Acceptance rate = 0.6403
### Starting Adaptation Chain 2/2 ###
p-value = 0.00043976812774581105
Acceptance rate = 0.6397


<mh_utils.MCMCISPermutationTest at 0x111fbe710>

In [None]:
# Simple example: Monte Carlo permutation test with 1,000,000 random relabelings.
# NOTE: with exact=False, permutation_test builds its own generator via
# np.random.default_rng, which np.random.seed does not seed. The call below
# therefore only resets NumPy's legacy global RNG; the printed estimate is not
# reproducible from this seed and will vary between runs.
np.random.seed(111)
p, t_obs, K, is_exact = permutation_test(s1, s2, alternative="greater", exact=False, max_enumerations=1_000_000)
print(p)

0.000363


In [16]:
# Simple example: exact permutation test as a reference value for the
# estimates above. With the 15- and 10-element samples defined earlier this
# enumerates C(25, 15) = 3,268,760 labelings, below the default
# max_enumerations cap of 5,000,000.
p, t_obs, K, is_exact = permutation_test(s1, s2, alternative="greater", exact=True)
print(p)

0.0003876087568374552
