In [1]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using BenchmarkTools

┌ Info: Precompiling SnpArrays [4e780e97-f5bf-4111-9dc4-b70aaf691b06]
└ @ Base loading.jl:1317
┌ Info: Precompiling MendelIHT [921c7187-1484-5754-b919-5d3ed9ac03c4]
└ @ Base loading.jl:1317


## Correctness: No center/scale/impute
We want to compute $C = AB$, where
+ $C$ is $m \times p$
+ $A$ is $m \times n$ (the SnpArray wrapped as a `SnpLinAlg`)
+ $B$ is $n \times p$
+ the SnpArray model is additive, dominant, or recessive

In [28]:
# Correctness check without centering/scaling/imputation: compare the SnpLinAlg
# matrix-matrix `mul!` against a dense Float64 conversion of the same SnpArray.
# NOTE(review): the recorded run of this cell threw a `vashr` MethodError from
# VectorizationBase inside `mul!`.
model = ADDITIVE_MODEL
center = false
scale = false
impute = false
m = 4097   # rows of A (the simulated SnpArray)
n = 1025   # columns of A / rows of B
p = 1025   # columns of B
x = simulate_random_snparray(undef, m, n)

A = SnpLinAlg{Float64}(x, model=model, impute=impute, center=center, scale=scale)
B = ones(n, p)
C = zeros(m, p)
LinearAlgebra.mul!(C, A, B)

# Dense reference: materialize x as Float64 with the same options, then multiply by B.
Ctrue = convert(Matrix{Float64}, x, impute=impute, model=model, center=center, scale=scale) * B
@show all(C .≈ Ctrue)

LoadError: MethodError: no method matching vashr(::VectorizationBase.VecUnroll{1, 4, UInt8, VectorizationBase.Vec{4, UInt8}}, ::VectorizationBase.Vec{4, UInt8})
Closest candidates are:
  vashr(::Any, ::Static.StaticInt{M}) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:38
  vashr(::Static.StaticInt{M}, ::Any) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:36
  vashr(::VectorizationBase.MM{W, X, T1}, ::VectorizationBase.AbstractSIMDVector{W, T2}) where {W, X, T1<:Union{Int16, Int32, Int64, Int8}, T2<:Union{UInt16, UInt32, UInt64, UInt8}} at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/ranges.jl:205
  ...

In [25]:
# Display the SnpLinAlg result C for comparison with the dense reference Ctrue.
# NOTE(review): in the recorded output only every fourth row (1, 5, 9, 13, ...) is
# nonzero, while Ctrue has no zero rows — looks like only one of each group of four
# rows was written; confirm.
C

4097×1025 Matrix{Float64}:
 2096.0  2096.0  2096.0  2096.0  2096.0  …  2096.0  2096.0  2096.0  2096.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
 2110.0  2110.0  2110.0  2110.0  2110.0     2110.0  2110.0  2110.0  2110.0
    0.0     0.0     0.0     0.0     0.0  …     0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
 2036.0  2036.0  2036.0  2036.0  2036.0     2036.0  2036.0  2036.0  2036.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0  …     0.0     0.0     0.0     0.0
    0.0     0.0     0.0     0.0     0.0        0.0     0.0     0.0     0.0
 2082.0  2082.0  2082.0  2082.0  2082.0     2082.0  2082.0  2082.0  2082.

In [26]:
# Display the dense reference product Ctrue.
Ctrue

4097×1025 Matrix{Float64}:
 524.0  524.0  524.0  524.0  524.0  …  524.0  524.0  524.0  524.0  524.0
 524.0  524.0  524.0  524.0  524.0     524.0  524.0  524.0  524.0  524.0
 525.0  525.0  525.0  525.0  525.0     525.0  525.0  525.0  525.0  525.0
 523.0  523.0  523.0  523.0  523.0     523.0  523.0  523.0  523.0  523.0
 553.0  553.0  553.0  553.0  553.0     553.0  553.0  553.0  553.0  553.0
 524.0  524.0  524.0  524.0  524.0  …  524.0  524.0  524.0  524.0  524.0
 502.0  502.0  502.0  502.0  502.0     502.0  502.0  502.0  502.0  502.0
 531.0  531.0  531.0  531.0  531.0     531.0  531.0  531.0  531.0  531.0
 516.0  516.0  516.0  516.0  516.0     516.0  516.0  516.0  516.0  516.0
 505.0  505.0  505.0  505.0  505.0     505.0  505.0  505.0  505.0  505.0
 499.0  499.0  499.0  499.0  499.0  …  499.0  499.0  499.0  499.0  499.0
 516.0  516.0  516.0  516.0  516.0     516.0  516.0  516.0  516.0  516.0
 512.0  512.0  512.0  512.0  512.0     512.0  512.0  512.0  512.0  512.0
   ⋮                    

## Correctness: center/scale/impute

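If we center and scale the SnpArray, with $\mu_k$ and $\sigma_k$ the mean and standard deviation of column $k$ of $A$, we have
$$
C_{ij} = \sum_{k} \left(\frac{A_{ik} - \mu_k}{\sigma_k}\right)B_{kj} = \sum_{k} \frac{A_{ik}B_{kj}}{\sigma_k} - \sum_{k} \frac{\mu_k B_{kj}}{\sigma_k}
$$
The second sum does not depend on $i$, so it only needs to be computed once per column $j$ of $B$.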

In [51]:
# Correctness check with centering and scaling (and optional imputation): compare the
# SnpLinAlg matrix-matrix `mul!` against the dense, standardized Float64 conversion.
model = ADDITIVE_MODEL
center = true
scale = true
impute = false
m = 4097   # rows of A
n = 1025   # columns of A / rows of B
p = 1025   # columns of B
x = simulate_random_snparray(undef, m, n)
# Only when `impute` is enabled: overwrite roughly 1% of entries with 0x01 so there is
# missing data to impute. With impute = false (as set above) this loop is skipped.
if impute
    for j in 1:n, i in 1:m
        rand() < 0.01 && (x[i, j] = 0x01) # create ~1% missings
    end
end

A = SnpLinAlg{Float64}(x, model=model, impute=impute, center=center, scale=scale)
B = ones(n, p)
C = zeros(m, p)
LinearAlgebra.mul!(C, A, B)

# Dense reference built with exactly the same model/impute/center/scale options.
Ctrue = convert(Matrix{Float64}, x, impute=impute, model=model, center=center, scale=scale) * B
@show all(C .≈ Ctrue)

all(C .≈ Ctrue) = true


true

## Speed: one SnpLinAlg-matrix product vs. multiple SnpLinAlg-vector products

In [5]:
"""
    adhoc_mul!(out, st, v)

Compute `out = st * v` one column at a time: for every column `i` of `v`, call the
SnpLinAlg matrix-vector `mul!` with `view(out, :, i)` and `view(v, :, i)`. This is the
baseline that the single SnpLinAlg matrix-matrix `mul!` is benchmarked against.

# Arguments
- `out::AbstractMatrix`: destination; column `i` receives `st * v[:, i]`.
- `st::AbstractSnpLinAlg`: the SnpArray-backed linear operator.
- `v::AbstractMatrix`: right-hand side, one product per column.

# Returns
`out`, following the `mul!` convention.

# Throws
`DimensionMismatch` if `out` and `v` do not have the same column axis.
"""
function adhoc_mul!(
    out::AbstractMatrix,
    st::AbstractSnpLinAlg,
    v::AbstractMatrix)
    # Without this check, extra columns of `out` would be silently left untouched, and
    # too few columns would surface as an obscure BoundsError from `view`.
    axes(out, 2) == axes(v, 2) || throw(DimensionMismatch(
        "out has columns $(axes(out, 2)) but v has columns $(axes(v, 2))"))
    for i in axes(v, 2)
        # Views avoid copying each column; the vector mul! writes straight into `out`.
        SnpArrays.mul!(view(out, :, i), st, view(v, :, i))
    end
    return out
end

adhoc_mul! (generic function with 1 method)

### r = 2

In [11]:
# Setup for r = 2: build a random SnpArray and check that the SnpLinAlg matrix-matrix
# `mul!` agrees with the column-by-column `adhoc_mul!` baseline before benchmarking.
n = 5092   # number of samples
p = 10000  # number of SNPs
r = 2      # number of traits
x = simulate_random_snparray(undef, n, p)

# test correctness: both methods should give the same n × r result
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
B = ones(p, r)
C = zeros(n, r)
Ctest = zeros(n, r)
LinearAlgebra.mul!(C, A, B)
adhoc_mul!(Ctest, A, B)
all(Ctest .≈ C)

M = 1273, Miter = 1, Mrem = 228
N = 10000, Niter = 9, Nrem = 784
P = 2, Piter = 0, Prem = 2


true

In [16]:
# Benchmark: one SnpLinAlg matrix-matrix mul! producing the n × r result C.
@benchmark LinearAlgebra.mul!($C, $A, $B) # SnpLinAlg-matrix

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     56.664 ms (0.00% GC)
  median time:      59.099 ms (0.00% GC)
  mean time:        59.197 ms (0.00% GC)
  maximum time:     69.516 ms (0.00% GC)
  --------------
  samples:          85
  evals/sample:     1

In [17]:
# Benchmark: the same product computed as one SnpLinAlg matrix-vector mul! per column of B.
@benchmark adhoc_mul!($Ctest, $A, $B) # multiple SnpLinAlg-vector

BenchmarkTools.Trial: 
  memory estimate:  288 bytes
  allocs estimate:  4
  --------------
  minimum time:     60.921 ms (0.00% GC)
  median time:      62.778 ms (0.00% GC)
  mean time:        62.858 ms (0.00% GC)
  maximum time:     75.957 ms (0.00% GC)
  --------------
  samples:          80
  evals/sample:     1

In [18]:
# Dense baseline: convert A to a Float64 matrix and multiply with BLAS on 8 threads.
# The BLAS thread count is not restored at the end of this cell; the next BLAS cell sets it to 1.
Afloat = convert(Matrix{Float64}, A)
BLAS.set_num_threads(8)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 8 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     17.510 ms (0.00% GC)
  median time:      17.860 ms (0.00% GC)
  mean time:        18.052 ms (0.00% GC)
  maximum time:     21.712 ms (0.00% GC)
  --------------
  samples:          277
  evals/sample:     1

In [19]:
# Dense baseline restricted to a single BLAS thread (comparable to the single-threaded SnpLinAlg runs).
BLAS.set_num_threads(1)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 1 thread

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     48.364 ms (0.00% GC)
  median time:      50.801 ms (0.00% GC)
  mean time:        51.121 ms (0.00% GC)
  maximum time:     56.792 ms (0.00% GC)
  --------------
  samples:          98
  evals/sample:     1

### r = 5

In [8]:
# Setup for r = 5: here the number of traits (columns of B) is named `q`, playing the
# role `r` has in the r = 2 setup. Rebuild the data and check that the matrix-matrix
# `mul!` and the column-by-column `adhoc_mul!` agree before benchmarking.
n = 5092   # number of samples
p = 10000  # number of SNPs
q = 5      # number of traits (columns of B)
x = simulate_random_snparray(undef, n, p)

# test correctness: both methods should give the same n × q result
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
B = ones(p, q)
C = zeros(n, q)
Ctest = zeros(n, q)
LinearAlgebra.mul!(C, A, B)
adhoc_mul!(Ctest, A, B)
all(Ctest .≈ C)

true

In [9]:
# Benchmark: one SnpLinAlg matrix-matrix mul! producing the n × q result C.
@benchmark LinearAlgebra.mul!($C, $A, $B) # SnpLinAlg-matrix

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     63.364 ms (0.00% GC)
  median time:      66.311 ms (0.00% GC)
  mean time:        66.760 ms (0.00% GC)
  maximum time:     82.072 ms (0.00% GC)
  --------------
  samples:          75
  evals/sample:     1

In [22]:
# Benchmark: the same product computed as one SnpLinAlg matrix-vector mul! per column of B.
@benchmark adhoc_mul!($Ctest, $A, $B) # multiple SnpLinAlg-vector

BenchmarkTools.Trial: 
  memory estimate:  720 bytes
  allocs estimate:  10
  --------------
  minimum time:     151.571 ms (0.00% GC)
  median time:      153.847 ms (0.00% GC)
  mean time:        154.455 ms (0.00% GC)
  maximum time:     165.740 ms (0.00% GC)
  --------------
  samples:          33
  evals/sample:     1

In [23]:
# Dense baseline: convert A to a Float64 matrix and multiply with BLAS on 8 threads.
# The BLAS thread count is not restored at the end of this cell; the next BLAS cell sets it to 1.
Afloat = convert(Matrix{Float64}, A)
BLAS.set_num_threads(8)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 8 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     17.992 ms (0.00% GC)
  median time:      18.381 ms (0.00% GC)
  mean time:        18.582 ms (0.00% GC)
  maximum time:     21.774 ms (0.00% GC)
  --------------
  samples:          269
  evals/sample:     1

In [24]:
# Dense baseline restricted to a single BLAS thread.
BLAS.set_num_threads(1)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 1 thread

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     54.639 ms (0.00% GC)
  median time:      58.390 ms (0.00% GC)
  mean time:        57.967 ms (0.00% GC)
  maximum time:     63.756 ms (0.00% GC)
  --------------
  samples:          87
  evals/sample:     1

## gesp vs @view

In [12]:
# Setup for the "gesp vs @view" comparison: a 5092 × 1000 SnpArray times a 1000 × 1000
# matrix of ones. NOTE(review): the two benchmark cells that follow run the identical
# call and differ only in their label, so the compared variants presumably lived in the
# SnpArrays source loaded at the time each cell ran — confirm.
n = 5092   # number of samples
p = 1000   # number of SNPs
q = 1000   # columns of B
x = simulate_random_snparray(undef, n, p)
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
B = ones(p, q)
C = zeros(n, q);

In [13]:
# Benchmark labelled "gesp" (presumably the package's mul! implemented with gesp at this point — confirm).
@benchmark LinearAlgebra.mul!($C, $A, $B) # gesp

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.104 s (0.00% GC)
  median time:      1.113 s (0.00% GC)
  mean time:        1.113 s (0.00% GC)
  maximum time:     1.118 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

In [16]:
# Benchmark labelled "@view" (same call as the "gesp" cell; presumably run against a
# version of the package's mul! using @view — confirm).
@benchmark LinearAlgebra.mul!($C, $A, $B) # @view

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.112 s (0.00% GC)
  median time:      1.120 s (0.00% GC)
  mean time:        1.120 s (0.00% GC)
  maximum time:     1.126 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

## $Ax$

In [132]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

# Check the SnpLinAlg matrix-vector product A * b against a dense Float64 reference.
# Known issue at the time of writing: any n in 4097:4099 fails this check; n = 4093 avoids that range.
n = 4093
p = 10000
x = simulate_random_snparray(undef, n, p, min_ma=0)

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
b = ones(p)
c = A * b
ctrue = convert(Matrix{Float64}, A) * b
@test all(c .≈ ctrue)

M = 1023, Miter = 0, Mrem = 253, rows_filled = 4093 


[32m[1mTest Passed[22m[39m

## $A^tx$

In [8]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

# Check the SnpLinAlg transposed matrix-vector product A' * b against a dense Float64 reference.
# Reported earlier: any n in 8193:8195 failed this check.
# NOTE(review): the recorded run of this cell uses n = 8193 and passes, so that note may
# be out of date — confirm.
n = 8193
p = 1000
x = simulate_random_snparray(undef, n, p, min_ma=0)

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
b = ones(n)
c = A' * b
ctrue = convert(Matrix{Float64}, A)' * b
@test all(c .≈ ctrue)

[32m[1mTest Passed[22m[39m

## $Ax$ with mean impute

In [23]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

# Check A * b with mean imputation: insert a single missing genotype, then compare the
# SnpLinAlg product (impute=true) against the dense conversion built with the same options.
n = 4097
p = 10000
x = simulate_random_snparray(undef, n, p, min_ma=0) # no missing data
# Column 1025 is just past 1024; presumably chosen to hit a blocking boundary — confirm.
x[1, 1025] = 0x01 # missing

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=true, center=false, scale=false)
Atrue = convert(Matrix{Float64}, x, model=ADDITIVE_MODEL, impute=true, center=false, scale=false)
b = ones(p)
c = A * b
ctrue = Atrue * b
@test all(c .≈ ctrue)

[32m[1mTest Passed[22m[39m

In [7]:
using LinearAlgebra
using Random
using LoopVectorization

# Minimal reproducer for a LoopVectorization `@avx` failure on a packed-genotype GEMM.
#
# For every column `c` of `V`, accumulates `out[:, c] += G * V[:, c]`, where `G` is decoded
# from the bytes of `s`: byte `s[l, j]` holds four 2-bit codes (least-significant pair
# first) for rows `4(l-1)+1 : 4(l-1)+4` of `G`, column `j`. Each code `Aij` is decoded as
# `(Aij >= 2) + (Aij == 3)`, i.e. 0 => 0, 1 => 0, 2 => 1, 3 => 2 (presumably SnpArrays'
# additive model with the missing code 1 mapped to 0 — confirm).
#
# This variant puts `@avx` on the OUTERMOST loop (over the columns of `V`). The recorded
# call to it throws a `vashr` MethodError from VectorizationBase. `this_gemm_works` differs
# only in applying `@avx` to the inner `j`/`l` loops.
#
# Arguments (requirements implied by the indexing below):
#   out : at least 4 * (size(s, 1) >> 2) rows and size(V, 2) columns; accumulated into
#         with `+=`, not overwritten.
#   s   : Matrix{UInt8} of packed 2-bit codes.
#   V   : size(s, 2) rows.
# NOTE(review): k is computed from size(s, 1), so only bytes s[1:k, :] are read and rows
# 1:4k of `out` are written — size(s, 1) appears to be treated as the row count of G
# rather than the byte count; confirm.
function this_gemm_fails(out, s::Matrix{UInt8}, V)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    k = srows >> 2  # number of full groups of four rows
    rem = srows & 3 # leftover rows (srows mod 4); not processed yet, see TODO. Shadows Base.rem locally.
    @avx for c in 1:Vcols
        for j in 1:scols
            for l in 1:k
                block = s[l, j]
                # unrolled loop over the four 2-bit fields (p = 1:4) packed in `block`
                p = 1
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 2
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 3
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 4
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
            end
        end
    end
    # TODO handle rem
end

# Packed-genotype GEMM kernel with `@avx` applied only to the inner `j`/`l` loops; the
# loop over the columns `c` of `V` is a plain Julia loop.
#
# For every column `c` of `V`, accumulates `out[:, c] += G * V[:, c]`, where byte
# `s[l, j]` holds four 2-bit codes (least-significant pair first) for rows
# `4(l-1)+1 : 4(l-1)+4` of `G`, column `j`. Each code `Aij` is decoded as
# `(Aij >= 2) + (Aij == 3)`, i.e. 0 => 0, 1 => 0, 2 => 1, 3 => 2 (presumably SnpArrays'
# additive model with the missing code 1 mapped to 0 — confirm).
#
# NOTE(review): despite the name, the recorded call `this_gemm_works(out, s, V)` also
# threw a `vashr` MethodError (with `Vec{4, Float64}` arguments) — confirm against the
# installed LoopVectorization/VectorizationBase versions.
#
# Arguments (requirements implied by the indexing below):
#   out : at least 4 * (size(s, 1) >> 2) rows and size(V, 2) columns; accumulated into
#         with `+=`, not overwritten.
#   s   : Matrix{UInt8} of packed 2-bit codes.
#   V   : size(s, 2) rows.
# NOTE(review): only bytes s[1:k, :] with k = size(s, 1) >> 2 are read — size(s, 1)
# appears to be treated as the row count of G rather than the byte count; confirm.
function this_gemm_works(out, s::Matrix{UInt8}, V)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    k = srows >> 2  # number of full groups of four rows
    rem = srows & 3 # leftover rows (srows mod 4); not processed yet, see TODO. Shadows Base.rem locally.
    for c in 1:Vcols
        @avx for j in 1:scols
            for l in 1:k
                block = s[l, j]
                # unrolled loop over the four 2-bit fields (p = 1:4) packed in `block`
                p = 1
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 2
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 3
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 4
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
            end
        end
    end
    # TODO handle rem
end

this_gemm_works (generic function with 1 method)

In [8]:
# Small inputs for the kernel reproducers: s has 100 rows, so k = 100 >> 2 = 25 bytes per
# column are read and rows 1:100 of `out` are written; V has one row per column of s.
out = zeros(100, 10)
s = rand(UInt8, 100, 100)
V = rand(100, 10)
this_gemm_works(out, s, V) # intended to run without error (recorded output shows a vashr error)

LoadError: MethodError: no method matching vashr(::VectorizationBase.Vec{4, Float64}, ::VectorizationBase.Vec{4, Float64})
Closest candidates are:
  vashr(::Any, ::Static.StaticInt{M}) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:38
  vashr(::Static.StaticInt{M}, ::Any) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:36

In [9]:
# Run the outer-loop-@avx variant on the same `out`, `s`, `V` defined for `this_gemm_works`.
this_gemm_fails(out, s, V) # throws vashr error

LoadError: MethodError: no method matching vashr(::VectorizationBase.Vec{4, Float64}, ::VectorizationBase.Vec{4, Float64})
Closest candidates are:
  vashr(::Any, ::Static.StaticInt{M}) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:38
  vashr(::Static.StaticInt{M}, ::Any) where M at /Users/biona001/.julia/packages/VectorizationBase/bYx3Z/src/static.jl:36