In [40]:
using Random
using Distributions
using FLoops
using Base.Threads
using BenchmarkTools

include("data.jl")
include("util.jl")

include("intervals/permutation.jl")
include("intervals/bootstrap.jl")
include("intervals/t.jl")

Threads.nthreads()

8

In [14]:
dtype = Float32
seed = 123

# DATA CONFIG

alpha = 0.05

# data
B  = 100   # num. coverage probabilities per boxplot
S  = 4300  # num. samples per coverage probability
nx = 8   # size of group 1
ny = 9    # size of group 2
if binomial(nx+ny, nx) > 30_000
    px, py = partition(nx, ny, 10_000)
else
    px, py = partition(nx, ny)
end

bits = vcat(ones(Int, nx), zeros(Int, ny))
addx, addy = bits[px], bits[py]

([1 1 … 0 0; 1 1 … 0 0; … ; 1 1 … 0 0; 1 0 … 0 0], [0 0 … 0 0; 0 0 … 1 1; … ; 0 0 … 1 1; 0 1 … 1 1])

In [3]:
# POPULATION SETTINGS

Random.seed!(123)

distrTypeX = Gamma{dtype}
X_shape = random(Uniform(2, 4), B)
X_scale = random(Uniform(0.5, 2), B)
distrX = map(distrTypeX, X_shape, X_scale)

distrTypeY = Gumbel{dtype}
Y_loc = random(Uniform(0, 1), B)
Y_scale = random(Uniform(2, 4), B)
distrY = map(distrTypeY, Y_loc, Y_scale)
;

In [4]:
deltas = @. mean(distrX) - mean(distrY)

@show distrX[1:2]
@show distrY[1:2]
@show deltas[1:2];

distrX[1:2] = Gamma{Float32}[Gamma{Float32}(α=3.813f0, θ=1.931f0), Gamma{Float32}(α=2.887f0, θ=1.769f0)]
distrY[1:2] = Gumbel{Float32}[Gumbel{Float32}(μ=0.965f0, θ=2.056f0), Gumbel{Float32}(μ=0.071f0, θ=2.495f0)]
deltas[1:2] = Float32[5.211148, 3.5959504]


In [5]:
Random.seed!(123)
xs = [dtype.(rand(distrX[i], nx, S)) for i in 1:B]
ys = [dtype.(rand(distrY[i], ny, S)) for i in 1:B]
@show size(ys)
ys[2]

size(ys) = (100,)


9×4300 Matrix{Float32}:
 11.8033    -1.09869    1.77469   …   7.44399     2.78415   -0.42669
 -0.443109  -0.590155  -2.27474      -2.87445     1.81832   -4.23349
  2.44806    1.33871    0.65206       7.27469    -2.28245   -2.09474
 -1.61376   -1.32229    0.677952      4.31303    -0.77003    3.49074
  2.9012     1.09572    4.57841       7.7269     -0.548287   1.50053
 -2.78064    1.03421    2.48815   …  -0.625538    9.11552   -4.146
 -0.144154  -1.39356    0.984717      2.92807     0.873052  -0.782865
  1.74981    4.19627    2.87135      -0.802826    3.80352   -0.383334
 -3.9335    -2.3745     2.82539      -0.0943729   8.94879   -2.37657

In [6]:
# flatten into 3D matrix
X = reshape(hcat(xs...), nx, S, B)
Y = reshape(hcat(ys...), ny, S, B)
Y[:,:,2]

9×4300 Matrix{Float32}:
 11.8033    -1.09869    1.77469   …   7.44399     2.78415   -0.42669
 -0.443109  -0.590155  -2.27474      -2.87445     1.81832   -4.23349
  2.44806    1.33871    0.65206       7.27469    -2.28245   -2.09474
 -1.61376   -1.32229    0.677952      4.31303    -0.77003    3.49074
  2.9012     1.09572    4.57841       7.7269     -0.548287   1.50053
 -2.78064    1.03421    2.48815   …  -0.625538    9.11552   -4.146
 -0.144154  -1.39356    0.984717      2.92807     0.873052  -0.782865
  1.74981    4.19627    2.87135      -0.802826    3.80352   -0.383334
 -3.9335    -2.3745     2.82539      -0.0943729   8.94879   -2.37657

In [44]:
function save_ci_results(results, methodId, B, S, pooled=nothing, two_sided=nothing; prefix="", dir="./")
    averages = []

    for batchId in 1:B
        batch = results[methodId, batchId, :]
        coverage = sum([j for (j, _) in batch]) / S
        width = sum([j for (_, j) in batch]) / S
        push!(averages, (coverage, width))
    end
    if isnothing(two_sided)
        save(averages, distrX[1:B], distrY[1:B], alpha, prefix=prefix, dir=dir)
    else
        alpha_ = two_sided ? alpha : alpha / 2
        save(averages, distrX[1:B], distrY[1:B], alpha_, pooled, two_sided, prefix=prefix, dir=dir)
    end
end

function save_permutation_results(results, B, S; prefix="", dir="./")
    i = 1
    per_method = []
    for two_sided in [true, false]
        for pooled in [true, false]
            if i in [2, 4]
                save_ci_results(results, i, B, S, pooled, two_sided, prefix=prefix, dir=dir)
            end
            i += 1
        end
    end
end

save_permutation_results (generic function with 1 method)

In [8]:
struct P
    mean_og   # unshifted group mean
    var_og    # unshifted group variance
    nshift    # number of items to be shifted
    shift_sum # original sum of items to be shifted
    n         # group size
end

function cache(groups, masks)
    mean_og = mean(groups, dims=1)
    var_og = var(groups, dims=1)
    nshift = sum(masks, dims=1)
    shift_sum = sum(groups .* masks, dims=1)
    return P(mean_og, var_og, nshift, shift_sum, size(groups, 1))
end

cache (generic function with 1 method)

In [43]:
x = X[:,1,1]
y = Y[:,1,1]
pooled = vcat(x, y)
xs = pooled[px]
ys = pooled[py]
xcache = cache(xs, addx)
ycache = cache(ys, addy)

wide, narrow = t_estimates(x, y, false)
permInterval(xcache, ycache, wide, narrow, false, 0.05, twoSided, twoSided, 0.005)

(0.3418512207780364, 6.443799878256071)

In [42]:
T = Threads.nthreads()
results = Array{Union{Tuple, Nothing}, 3}(nothing, 6, B, S)

#@time Threads.@threads for (i,j) in collect(Iterators.product(1:B, 1:S)) # 15.52 sec on (B,S) = (5, 1800)
@time @floop ThreadedEx(basesize=div(B*S, T)) for b in 1:10, s in 1:100
    @inbounds x = X[:,s,b]
    @inbounds y = Y[:,s,b]
    
    pooled = vcat(x, y)
    xs = @inbounds pooled[px]
    ys = @inbounds pooled[py]
    xcache = cache(xs, addx)
    ycache = cache(ys, addy)
    wide, narrow = t_estimates(x, y, false)
 
    #results[1, b, s] = permInterval(xcache, ycache, deltas[b], true, alpha, twoSided, twoSided, 0.0005)    
    results[2, b, s] = permInterval(xcache, ycache, wide, narrow, deltas[b], false, alpha, twoSided, twoSided, 0.0005)
    #results[3, b, s] = permInterval(xcache, ycache, deltas[b], true, alpha/2, greater, smaller, 0.0005)
    results[4, b, s] = permInterval(xcache, ycache, deltas[b], false, alpha/2, greater, smaller, 0.0005)
    
    """
    results[5, b, s] = bootstrap(x, y, deltas[b], alpha, nsamples=10_000)
    results[6, b, s] = tconf(x, y, deltas[b], alpha, false)
    """
end

dir = "../results/" * string(nx) * "_" * string(ny) * "/2/"
#save_permutation_results(results, B, S; dir=dir)
#save_ci_results(results, 6, B, S; prefix="bs", dir=dir)

 18.353633 seconds (3.04 M allocations: 46.353 GiB, 11.09% gc time, 0.44% compilation time)


"../results/8_9/2/"