# Multi-Arm MRT Sample Size Calculator

**Modified from:** [MRT-SS Calculator: A Sample Size Calculator for Micro-Randomized Trials](https://d3center.shinyapps.io/mrt-ss-calculator/), by Liao et al. (2016).

arXiv: [MRT-SS Calculator: An R Shiny Application for Sample Size Calculation in Micro-Randomized Trials](https://arxiv.org/abs/1609.00695)

---

### Enhanced Multi-Arm Flexibility Compared to Liao et al.’s MRT-SS Calculator:

- The original MRT-SS calculator focused on two-arm trials or assumed equal allocation across arms.  
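- This implementation supports any number of arms (K ≥ 2). It computes the variance term automatically by coding the baseline arm as 0 and every other arm as 1, so the variance depends only on the baseline arm's randomization probability.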
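To make the variance term concrete: with $A = 0$ when the baseline arm is assigned, $A = 1$ otherwise, and $p_b$ the baseline arm's randomization probability, the sum computed in the code simplifies as follows. This is a short derivation from the code's definition, not a formula quoted from Liao et al.

```latex
\begin{aligned}
\mathbb{E}[A] &= 1 - p_b, \\
\operatorname{Var}(A) &= p_b \bigl(0 - (1 - p_b)\bigr)^2 + (1 - p_b)\bigl(1 - (1 - p_b)\bigr)^2 \\
&= p_b (1 - p_b)^2 + (1 - p_b)\, p_b^2 \\
&= p_b (1 - p_b).
\end{aligned}
```

How the remaining mass $1 - p_b$ is split across the other $K - 1$ arms does not affect the result. For example, five equal arms give $p_b = 0.2$ and $\operatorname{Var}(A) = 0.16$, compared with $0.25$ for a balanced two-arm design.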


### Description:

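The `mrt_sample_size()` function computes the number of participants required for a micro-randomized trial (MRT) with a standardized proximal outcome. You provide:

- **Number of arms** (`k`), including the baseline arm  
- **Trial duration** in days (`days`)  
- **Decision points per day** (`decision_points_per_day`)  
- **Expected availability** (`availability`)  
- **Minimum detectable standardized effect size** (`treatment_effect`)  
- **Randomization probability vector** (`rand_prob`)  
- Optionally, the **baseline arm** (`baseline_idx`), **significance level** (`alpha`), **power** (`power`), and **outcome standard deviation** (`sigma`)  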

The function then:

1. Derives the variance of the randomized treatment indicator (A = 0 for the baseline arm, A = 1 for any other arm)  
2. Applies the large-sample MRT sample-size formula  
3. Optionally adds a “+2” small-sample correction to match the MRT-SS Shiny Calculator    
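Written out, steps 1 to 3 compute the following. The symbols are introduced here for readability: $D$ = `days`, $M$ = `decision_points_per_day`, $\tau$ = `availability`, $\delta$ = `treatment_effect`, and $\operatorname{Var}(A) = p_b(1 - p_b)$ for the baseline-vs-rest indicator with baseline probability $p_b$.

```latex
n^{*} = \frac{\left(z_{1-\alpha/2} + z_{1-\beta}\right)^{2} \sigma^{2}}
             {\delta^{2} \cdot D \, M \, \tau \cdot \operatorname{Var}(A)},
\qquad
n = \left\lceil n^{*} \right\rceil + 2 \cdot \mathbb{1}[\texttt{finite\_correction}]
```

As a worked check with $\alpha = 0.05$ and power $0.80$: $(z_{0.975} + z_{0.80})^2 = (1.95996 + 0.84162)^2 \approx 7.849$. A balanced two-arm trial over 30 days, with one always-available decision point per day and $\delta = 0.3$, has denominator $0.09 \cdot 30 \cdot 0.25 = 0.675$. That gives $n^{*} \approx 11.63$, so $n = 12$ without the correction and $14$ with it.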


```python
import math
from typing import Optional, Sequence

from scipy.stats import norm


def mrt_sample_size(
    k: int,
    days: int,
    decision_points_per_day: int,
    availability: float,
    treatment_effect: float,
    rand_prob: Sequence[float],
    baseline_idx: Optional[int] = None,
    alpha: float = 0.05,
    power: float = 0.80,
    sigma: float = 1.0,
    finite_correction: bool = True,
) -> int:
    """
    Calculate the required sample size for a Micro-Randomized Trial (proximal effect).

    Parameters
    ----------
    k : int
        Number of arms, including the baseline arm (k >= 2); must equal len(rand_prob).
    days : int
        Total number of trial days.
    decision_points_per_day : int
        Number of decision points per day.
    availability : float
        Expected availability, in (0, 1].
    treatment_effect : float
        Minimum detectable standardized proximal effect (in units of sigma); must be nonzero.
    rand_prob : Sequence[float]
        Randomization probability vector of length k (entries > 0, summing to 1).
    baseline_idx : int, optional
        Index of the baseline arm (0 to k-1). Defaults to the arm with the highest
        probability (the first such arm if there are ties).
    alpha : float, default 0.05
        Two-sided significance level.
    power : float, default 0.80
        Statistical power (1 - β).
    sigma : float, default 1.0
        Standard deviation of the proximal outcome.
    finite_correction : bool, default True
        Whether to add a small-sample correction of +2 (consistent with the MRT-SS Calculator).

    Returns
    -------
    int
        Required number of participants (rounded up).
    """
    # Validate randomization probabilities
    if k < 2:
        raise ValueError("k must be at least 2")
    if len(rand_prob) != k:
        raise ValueError("len(rand_prob) must equal k")
    if any(p <= 0 for p in rand_prob):
        raise ValueError("All probabilities in rand_prob must be > 0")
    if abs(sum(rand_prob) - 1) > 1e-8:
        raise ValueError("Sum of rand_prob must equal 1")

    # Validate the remaining design inputs (avoids division by zero below)
    if days <= 0 or decision_points_per_day <= 0:
        raise ValueError("days and decision_points_per_day must be positive")
    if not 0 < availability <= 1:
        raise ValueError("availability must be in (0, 1]")
    if treatment_effect == 0 or sigma <= 0:
        raise ValueError("treatment_effect must be nonzero and sigma positive")
    if not (0 < alpha < 1 and 0 < power < 1):
        raise ValueError("alpha and power must be in (0, 1)")

    # Select baseline arm: first arm with the highest probability by default.
    # max(range(k), key=...) works for any sequence, including NumPy arrays,
    # which lack the list .index() method.
    if baseline_idx is None:
        baseline_idx = max(range(k), key=lambda i: rand_prob[i])
    if not (0 <= baseline_idx < k):
        raise ValueError("baseline_idx is out of range")

    # Var(A) for the indicator A = 0 on the baseline arm and A = 1 otherwise.
    # Algebraically this equals p_b * (1 - p_b), with p_b = rand_prob[baseline_idx].
    mu = 1.0 - rand_prob[baseline_idx]  # E[A]
    var_a = sum(
        p * ((1 if i != baseline_idx else 0) - mu) ** 2
        for i, p in enumerate(rand_prob)
    )

    # Expected number of available decision points per participant
    total_dp = days * decision_points_per_day * availability

    # Standard normal quantiles
    z_alpha = norm.isf(alpha / 2)  # z_{1 - alpha/2}, two-sided test
    z_power = norm.isf(1 - power)  # z_{1 - beta}

    # Large-sample approximation of n
    n_star = ((z_alpha + z_power) ** 2 * sigma ** 2) / (
        treatment_effect ** 2 * total_dp * var_a
    )

    # Round up, then apply the optional +2 small-sample correction
    return math.ceil(n_star) + (2 if finite_correction else 0)
```

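The variance term reduces to `p_b * (1 - p_b)`, so the result can be cross-checked against a closed form that takes only the baseline probability. The sketch below uses a hypothetical helper, `closed_form_n`, which is not part of the original code. It relies on the standard library's `statistics.NormalDist` instead of SciPy, so it runs without extra dependencies. It should reproduce the numbers in the Example Usage output.

```python
import math
from statistics import NormalDist


def closed_form_n(
    p_baseline: float,
    days: int,
    decision_points_per_day: int,
    availability: float,
    treatment_effect: float,
    alpha: float = 0.05,
    power: float = 0.80,
    sigma: float = 1.0,
    finite_correction: bool = True,
) -> int:
    """Hypothetical helper: same large-sample formula, with Var(A) = p_b * (1 - p_b)."""
    std_normal = NormalDist()  # mean 0, standard deviation 1
    z_alpha = std_normal.inv_cdf(1 - alpha / 2)  # equals norm.isf(alpha / 2)
    z_power = std_normal.inv_cdf(power)          # equals norm.isf(1 - power)
    var_a = p_baseline * (1 - p_baseline)
    total_dp = days * decision_points_per_day * availability
    n_star = (z_alpha + z_power) ** 2 * sigma ** 2 / (
        treatment_effect ** 2 * total_dp * var_a
    )
    return math.ceil(n_star) + (2 if finite_correction else 0)


# Baseline probabilities of the example designs: 5, 3 and 2 equally weighted arms
for label, p_b in [("5-arm", 0.2), ("3-arm", 1 / 3), ("2-arm", 0.5)]:
    with_corr = closed_form_n(p_b, 30, 1, 1.0, 0.30)
    without_corr = closed_form_n(p_b, 30, 1, 1.0, 0.30, finite_correction=False)
    print(label, with_corr, without_corr)
# 5-arm 21 19
# 3-arm 16 14
# 2-arm 14 12
```

Agreement between the two versions is a consistency check on the variance computation only. It does not validate the formula against the MRT-SS Calculator itself.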
### Example Usage:

```python
if __name__ == "__main__":
    # 5-arm example; baseline_idx omitted, so it defaults to the first
    # arm with the highest probability (arm 0 here)
    params5 = dict(
        k=5,
        days=30,
        decision_points_per_day=1,
        availability=1.0,
        treatment_effect=0.30,
        rand_prob=[0.2, 0.2, 0.2, 0.2, 0.2],
    )
    n5_with = mrt_sample_size(**params5, finite_correction=True)
    n5_without = mrt_sample_size(**params5, finite_correction=False)

    # 3-arm example with an explicitly chosen baseline arm
    params3 = dict(
        k=3,
        days=30,
        decision_points_per_day=1,
        availability=1.0,
        treatment_effect=0.30,
        rand_prob=[1/3, 1/3, 1/3],
        baseline_idx=2,
    )
    n3_with = mrt_sample_size(**params3, finite_correction=True)
    n3_without = mrt_sample_size(**params3, finite_correction=False)

    # 2-arm example (balanced treatment vs. baseline)
    params2 = dict(
        k=2,
        days=30,
        decision_points_per_day=1,
        availability=1.0,
        treatment_effect=0.30,
        rand_prob=[0.5, 0.5],
    )
    n2_with = mrt_sample_size(**params2, finite_correction=True)
    n2_without = mrt_sample_size(**params2, finite_correction=False)

    print(f"5-arm MRT (with correction): {n5_with}")
    print(f"5-arm MRT (no correction):   {n5_without}\n")
    print(f"3-arm MRT (with correction): {n3_with}")
    print(f"3-arm MRT (no correction):   {n3_without}\n")
    print(f"2-arm MRT (with correction): {n2_with}")
    print(f"2-arm MRT (no correction):   {n2_without}")
```

Output:

```
5-arm MRT (with correction): 21
5-arm MRT (no correction):   19

3-arm MRT (with correction): 16
3-arm MRT (no correction):   14

2-arm MRT (with correction): 14
2-arm MRT (no correction):   12
```
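Only $p_b$ enters the variance, and $p_b(1 - p_b)$ peaks at $p_b = 0.5$. Under this formula, placing half of the randomization mass on the baseline arm therefore minimizes the required n, whatever the number of arms. The sketch below sweeps the baseline probability for a 5-arm design under the same settings as the examples. `rand_prob_with_baseline` and `required_n` are hypothetical helpers written for illustration: the first splits the remaining mass evenly, and the second re-implements the same closed form using only the standard library.

```python
import math
from statistics import NormalDist
from typing import List, Sequence


def rand_prob_with_baseline(k: int, p_baseline: float) -> List[float]:
    """Hypothetical helper: baseline arm first, remaining mass split evenly over k - 1 arms."""
    rest = (1 - p_baseline) / (k - 1)
    return [p_baseline] + [rest] * (k - 1)


def required_n(rand_prob: Sequence[float], baseline_idx: int = 0, days: int = 30,
               dp_per_day: int = 1, availability: float = 1.0, effect: float = 0.30,
               alpha: float = 0.05, power: float = 0.80, sigma: float = 1.0,
               correction: bool = True) -> int:
    """Hypothetical helper: large-sample formula with Var(A) = p_b * (1 - p_b)."""
    p_b = rand_prob[baseline_idx]
    z = NormalDist().inv_cdf(1 - alpha / 2) + NormalDist().inv_cdf(power)
    total_dp = days * dp_per_day * availability
    n_star = z ** 2 * sigma ** 2 / (effect ** 2 * total_dp * p_b * (1 - p_b))
    return math.ceil(n_star) + (2 if correction else 0)


for p_b in (0.2, 0.3, 0.4, 0.5, 0.6):
    probs = rand_prob_with_baseline(5, p_b)
    print(f"p_b = {p_b:.1f}: n = {required_n(probs)}")
# p_b = 0.2: n = 21
# p_b = 0.3: n = 16
# p_b = 0.4: n = 15
# p_b = 0.5: n = 14
# p_b = 0.6: n = 15
```

At $p_b = 0.5$ the 5-arm design needs the same 14 participants as the balanced 2-arm example. The calculation only sees the baseline-vs-rest split; it powers the contrast of the baseline arm against all other arms pooled, not comparisons among the non-baseline arms.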