In [1]:
using CUDA, Random
using GPUArrays: @allowscalar
CUDA.allowscalar(false)
include("utils.jl")
include("statistics.jl")

_var (generic function with 2 methods)

In [2]:
# data dimensions
B = 10  # num. coverage probabilities per boxplot
S = 5   # num. samples per coverage probability
nx = 4  # size of group 1
ny = 3  # size of group 2

# permutation test configuration
pooled = false
alpha = 0.05
deltas = repeat([0], B)    # true difference in means for each of the B batches
alternative = "two-sided"  # one of "two-sided", "smaller", "larger" (see `pval`)

# Generate data: both groups are standard normal, so the true difference is 0
Random.seed!(123)
x = randn(Float32, (B, S, nx))
y = randn(Float32, (B, S, ny))
# Bracketing bounds for the interval search in `permInterval`: `wide` (alpha=0.0001)
# is meant to lie outside and `narrow` (alpha=0.3) inside the permutation interval.
# NOTE(review): assumes `tconf` returns per-sample (lower, upper) bounds — confirm.
wide = tconf.(eachslice(x, dims=1), eachslice(y, dims=1), alpha=0.0001)
wide = hcat(wide...)
narrow = tconf.(eachslice(x, dims=1), eachslice(y, dims=1), alpha=0.3)
narrow = hcat(narrow...)

# Index matrices used by `pval` as `combined[px]` / `combined[py]` to enumerate
# rearrangements of the nx + ny observations (presumably all of them — see utils.jl).
px, py = partition(nx, ny)

# Move to GPU
x, y, wide, narrow = cu(x), cu(y), cu(wide), cu(narrow)

([-0.6457307 0.5931974 … 0.48432863 1.1067815; -1.4632514 -0.7684089 … -0.20241024 1.1486968; … ; -1.3416092 0.5053106 … 0.1707967 0.22824632; 0.41216165 -0.53829795 … -1.71712 0.59422016;;; 0.94035435 -0.68205047 … -1.5995136 1.4971732; -0.33309984 0.43087748 … -0.3802952 1.7855924; … ; 3.031836 -1.1985193 … 1.1820064 -0.13788761; -0.36440215 1.8255719 … -0.27137586 -0.66762507;;; 0.8151673 0.7156151 … 1.0914292 -0.26628616; 0.098473065 0.3785804 … -0.8792037 -0.25343972; … ; 1.5091707 -0.123943955 … 0.36522362 -1.0764854; -0.71137244 0.027235571 … 0.15584055 0.89249754;;; 1.7152905 -0.2818589 … 0.24685228 -0.0422467; 0.8787855 -2.3288715 … -1.4717366 -0.971406; … ; 0.15282853 -0.4139169 … -0.26430404 1.0259129; -0.3875826 -0.9293378 … 0.043618564 1.4719257], [-0.60569096 1.3072717 … -0.9987826 -0.8308392; -0.35481334 -0.4582605 … -1.1455126 0.10596891; … ; -0.038758337 0.42520037 … -0.14917293 -1.0552233; -0.7012483 0.50492907 … 0.99210846 -0.92516536;;; -0.85761726 0.19043577 … 0.32

In [3]:
# get first batch for debugging purposes
# Copy out batch 1 (index 1 along dims=1) as S×nx and S×ny matrices for debugging.
x1 = x[1, :, :]
y1 = y[1, :, :]
x1

5×4 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -0.645731   0.940354   0.815167   1.71529
  0.593197  -0.68205    0.715615  -0.281859
  0.261401   0.140672   3.14363    0.200474
  0.484329  -1.59951    1.09143    0.246852
  1.10678    1.49717   -0.266286  -0.0422467

In [4]:
"""
    pval(x, y, pooled=false, alternative="two-sided", delta=0; partitions=(px, py))

Permutation-test p-value for the difference in means of two samples.

Group 1 is shifted by `delta` (the null-hypothesis difference in means). The statistic
`t` is then evaluated on the observed split `(x .- delta, y)` and on every
rearrangement described by `partitions`. The result is the proportion of
rearrangements whose statistic is as or more extreme than the observed one.

# Arguments
- `x`: data for group 1 (1-D).
- `y`: data for group 2 (1-D).
- `pooled::Bool`: forwarded to `t`; assume equal (`true`) or unequal (`false`) variances.
- `alternative::String`: `"smaller"` (left tail) or `"larger"` (right tail). Any other
  value, e.g. `"two-sided"`, is treated as a two-sided test.
- `delta`: null-hypothesis difference in means.

# Keywords
- `partitions`: tuple `(px, py)` of index matrices into `vcat(x .- delta, y)`. Row `i`
  of `combined[px]` / `combined[py]` is the `i`-th rearrangement into the two groups.
  The number of rows of `px` is the denominator of the p-value. Defaults to the
  globals `px`, `py`.

# Returns
`Float64`: proportion of rearrangements with a test statistic as or more extreme than
the observed pair `(x, y)`.
"""
function pval(x, y, pooled=false, alternative="two-sided", delta=0; partitions=(px, py))
    pxs, pys = partitions

    x_shift = x .- delta            # shift group 1 under the null hypothesis
    t_obs = t(x_shift, y, pooled)   # test statistic for the observed data

    # Join the shifted pair and index out every rearrangement at once (row-wise).
    combined = vcat(x_shift, y)
    ts = t(combined[pxs], combined[pys], pooled)

    n_extreme = if alternative == "smaller"
        count(ts .<= t_obs)
    elseif alternative == "larger"
        count(ts .>= t_obs)
    else
        # Fallback: any other string (e.g. "two-sided", "two_sided") is two-sided.
        count(@. (ts <= -abs(t_obs)) | (ts >= abs(t_obs)))
    end

    return n_extreme / size(pxs, 1)
end


# `pval` works on plain 1-D inputs, one sample (row) at a time...
foreach(1:S) do row
    println(pval(x1[row, :], y1[row, :]))
end
# ...and yields the same values when broadcast over the rows with dot syntax.
pval.(eachrow(x1), eachrow(y1))

0.2
0.5142857142857142
0.4857142857142857
0.9142857142857143
0.34285714285714286


5-element Vector{Float64}:
 0.2
 0.5142857142857142
 0.4857142857142857
 0.9142857142857143
 0.34285714285714286

In [5]:
"""
    search(x, y, start, stop; pooled=false, alternative="two-sided", isLowerBound=true,
           margin=0.005, threshold=1.0, alpha=0.05, maxiter=10_000)

Bisect between `start` and `stop` for the null shift `delta` at which
`pval(x, y, pooled, alternative, delta)` crosses `alpha`. This is one endpoint of the
permutation-test confidence interval.

The p-values at `start` and `stop` must lie on opposite sides of `alpha`.

For the lower bound (`isLowerBound=true`):
- a midpoint p-value more than `margin` above `alpha` moves `stop` to the midpoint;
- one more than `margin` below `alpha` moves `start` to the midpoint.

For the upper bound the two cases are swapped.

The current midpoint is returned as soon as one of these holds:
1. the p-value changed by at most `threshold` percent since the previous iteration;
2. the p-value is within `margin` of `alpha`;
3. `maxiter` iterations have run without (1) or (2). A warning is logged in this case.
   This guards against looping forever when the p-value stays at zero, which makes
   the percent change `NaN` so criterion (1) can never fire.

Throws an `ArgumentError` if `start` and `stop` do not bracket `alpha`.
"""
function search(x, y, start, stop;
                pooled=false, alternative="two-sided", isLowerBound=true,
                margin=0.005, threshold=1.0, alpha=0.05, maxiter=10_000)
    p_start = pval(x, y, pooled, alternative, start)
    p_end = pval(x, y, pooled, alternative, stop)

    # p-values corresponding to `start` and `stop` must be on opposite sides of `alpha`
    # (written as `!(... <= 0)` so a NaN product is rejected, as the original check did)
    if !((p_start - alpha) * (p_end - alpha) <= 0)
        throw(ArgumentError("p-values at start ($p_start) and stop ($p_end) must lie " *
                            "on opposite sides of alpha ($alpha)"))
    end

    percent_change(old, new) = 100 * abs(new - old) / old

    p_prev = nothing
    delta = (start + stop) / 2
    for _ in 1:maxiter
        delta = (start + stop) / 2
        p_new = pval(x, y, pooled, alternative, delta)

        # (1) percent change in p-value is at most `threshold`
        if !isnothing(p_prev) && percent_change(p_prev, p_new) <= threshold
            return delta
        end

        # Distance of p_new from alpha, oriented so that a positive value means the
        # midpoint lies outside the interval and `stop` should move to it.
        compare = isLowerBound ? p_new - alpha : alpha - p_new
        if margin < compare
            stop = delta
        elseif margin < -compare
            start = delta
        else
            return delta  # (2) p-value is within `margin` of `alpha`
        end

        p_prev = p_new
    end

    @warn "search did not converge within $maxiter iterations; returning last midpoint" delta
    return delta
end

search (generic function with 1 method)

In [6]:
"""
    permInterval(x, y, wide, narrow, delta_true; pooled=false, alpha=0.05,
                 alternative="two-sided")

Return `true` if the approximate permutation-test confidence interval contains the
difference in population means `delta_true`, and `false` otherwise.

Each interval endpoint is located with `search`, which bisects on the permutation
p-value:
- the lower bound is searched between `wide[1]` and `narrow[1]`;
- the upper bound is searched between `narrow[2]` and `wide[2]`.

`search` requires the p-values at the two ends of each bracket to lie on opposite
sides of `alpha`, so `wide` should sit outside the interval and `narrow` inside it.

# Arguments
- `x`: data for group 1 (one sample).
- `y`: data for group 2 (one sample).
- `wide`: `(lower, upper)` pair bracketing the interval from outside.
- `narrow`: `(lower, upper)` pair bracketing the interval from inside.
- `delta_true`: difference in population means.

# Keywords
- `pooled::Bool`: assume pooled (equal) or unpooled variances.
- `alpha`: significance level.
- `alternative::String`: type of alternative hypothesis (`"two-sided"`, `"smaller"`,
  `"larger"`), forwarded to `search`.

# Returns
`Bool`: `lo <= delta_true <= hi` for the estimated interval `(lo, hi)`.
"""
function permInterval(x, y, wide, narrow, delta_true;
                      pooled=false, alpha=0.05, alternative="two-sided")
    # use binary search to find approximate permutation test confidence interval
    lo = search(x, y, wide[1], narrow[1];
                pooled=pooled, alpha=alpha, alternative=alternative, isLowerBound=true)
    hi = search(x, y, narrow[2], wide[2];
                pooled=pooled, alpha=alpha, alternative=alternative, isLowerBound=false)
    return lo <= delta_true <= hi
end

permInterval (generic function with 1 method)

In [7]:
"""
    coverage(xs, ys, wide, narrow, delta_true; pooled=false, alpha=0.05,
             alternative="two-sided")

Return the fraction of samples whose permutation-test confidence interval (see
`permInterval`) contains `delta_true`.

Row `i` of `xs` and of `ys` is sample `i` of group 1 and group 2. `wide[i]` and
`narrow[i]` are the bracketing bounds passed to `permInterval` for that sample. The
keywords are forwarded unchanged.
"""
function coverage(xs, ys, wide, narrow, delta_true;
                  pooled=false, alpha=0.05, alternative="two-sided")
    covered = permInterval.(eachrow(xs), eachrow(ys), wide, narrow, delta_true;
                            pooled=pooled, alpha=alpha, alternative=alternative)
    # Normalize by the number of samples actually evaluated, not a global constant.
    return count(covered) / length(covered)
end

coverage (generic function with 1 method)

In [None]:
# One coverage estimate per batch: slice x and y along dims=1, and pair each batch
# with its column of bracketing bounds and its true difference in means.
let batches_x = eachslice(x, dims=1), batches_y = eachslice(y, dims=1)
    coverage.(batches_x, batches_y, eachcol(wide), eachcol(narrow), eachrow(deltas);
              pooled=pooled, alpha=alpha, alternative=alternative)
end