In [1]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using BenchmarkTools

┌ Info: Precompiling SnpArrays [4e780e97-f5bf-4111-9dc4-b70aaf691b06]
└ @ Base loading.jl:1317
┌ Info: Precompiling MendelIHT [921c7187-1484-5754-b919-5d3ed9ac03c4]
└ @ Base loading.jl:1317


## Correctness: No center/scale/impute
We want to do $C = AB$ where
+ $C = m \times n$
+ $A = m \times n$
+ $B = n \times p$
+ SnpArray model = Additive, dominant, recessive

In [6]:
model = ADDITIVE_MODEL
center = false
scale = false
impute = false
m = 4097
n = 1025
p = 1025
x = simulate_random_snparray(undef, m, n)

A = SnpLinAlg{Float64}(x, model=model, impute=impute, center=center, scale=scale)
B = ones(n, p)
C = zeros(m, p)
LinearAlgebra.mul!(C, A, B)

Ctrue = convert(Matrix{Float64}, x, impute=impute, model=model, center=center, scale=scale) * B
@show all(C .≈ Ctrue)

all(C .≈ Ctrue) = true


true

## Correctness: center/scale/impute

If we want to center/scale the SnpArray, we have
$$
C_{ij} = \sum_{k} \left(\frac{A_{ik} - \mu_k}{\sigma_k^2}\right)B_{kj} = \sum_{k} \frac{A_{ik}B_{kj} - \mu_kB_{kj}}{\sigma_k^2}
$$

In [5]:
model = ADDITIVE_MODEL
center = true
scale = true
impute = true
m = 4097
n = 1025
p = 1025
x = simulate_random_snparray(undef, m, n)
if impute
    for j in 1:n, i in 1:m
        rand() < 0.01 && (x[i, j] = 0x01) # create ~1% missings
    end
end

A = SnpLinAlg{Float64}(x, model=model, impute=impute, center=center, scale=scale)
B = ones(n, p)
C = zeros(m, p)
LinearAlgebra.mul!(C, A, B)

Ctrue = convert(Matrix{Float64}, x, impute=impute, model=model, center=center, scale=scale) * B
@show all(C .≈ Ctrue)

all(C .≈ Ctrue) = true


true

## Speed: SnpLinAlg-(matrix) vs multiple SnpLinAlg-vector

In [20]:
# C = AB by multiple C[:, i] = AB[:, i]
function adhoc_mul!(
    out::AbstractMatrix, 
    st::AbstractSnpLinAlg,
    v::AbstractMatrix)
    for i in 1:size(v, 2)
        outi = @view(out[:, i])
        vi = @view(v[:, i])
        SnpArrays.mul!(outi, st, vi)
    end
end

adhoc_mul! (generic function with 1 method)

### r = 2

In [21]:
n = 5092   # number of samples
p = 10000  # number of SNPs
r = 2      # number of traits
x = simulate_random_snparray(undef, n, p)

# test correctness
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=true, center=true, scale=true)
B = ones(p, r)
C = zeros(n, r)
Ctest = zeros(n, r)
LinearAlgebra.mul!(C, A, B)
adhoc_mul!(Ctest, A, B)
all(Ctest .≈ C)

true

In [22]:
@benchmark LinearAlgebra.mul!($C, $A, $B) # SnpLinAlg-matrix

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     33.984 ms (0.00% GC)
  median time:      35.574 ms (0.00% GC)
  mean time:        35.681 ms (0.00% GC)
  maximum time:     48.416 ms (0.00% GC)
  --------------
  samples:          141
  evals/sample:     1

In [23]:
@benchmark adhoc_mul!($Ctest, $A, $B) # multiple SnpLinAlg-vector

BenchmarkTools.Trial: 
  memory estimate:  192 bytes
  allocs estimate:  2
  --------------
  minimum time:     33.801 ms (0.00% GC)
  median time:      35.612 ms (0.00% GC)
  mean time:        44.915 ms (0.00% GC)
  maximum time:     97.297 ms (0.00% GC)
  --------------
  samples:          112
  evals/sample:     1

In [24]:
Afloat = convert(Matrix{Float64}, A)
BLAS.set_num_threads(8)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 8 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     17.488 ms (0.00% GC)
  median time:      18.473 ms (0.00% GC)
  mean time:        19.304 ms (0.00% GC)
  maximum time:     23.870 ms (0.00% GC)
  --------------
  samples:          259
  evals/sample:     1

In [25]:
BLAS.set_num_threads(1)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 1 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     49.005 ms (0.00% GC)
  median time:      50.821 ms (0.00% GC)
  mean time:        51.015 ms (0.00% GC)
  maximum time:     57.021 ms (0.00% GC)
  --------------
  samples:          98
  evals/sample:     1

### r = 5

In [None]:
n = 5092
p = 10000
r = 5
x = simulate_random_snparray(undef, n, p)

# test correctness
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=true, center=true, scale=true)
B = ones(p, r)
C = zeros(n, r)
Ctest = zeros(n, r)
LinearAlgebra.mul!(C, A, B)
adhoc_mul!(Ctest, A, B)
all(Ctest .≈ C)

In [28]:
@benchmark LinearAlgebra.mul!($C, $A, $B) # SnpLinAlg-matrix

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     52.632 ms (0.00% GC)
  median time:      56.090 ms (0.00% GC)
  mean time:        62.038 ms (0.00% GC)
  maximum time:     134.616 ms (0.00% GC)
  --------------
  samples:          81
  evals/sample:     1

In [29]:
@benchmark adhoc_mul!($Ctest, $A, $B)

BenchmarkTools.Trial: 
  memory estimate:  480 bytes
  allocs estimate:  5
  --------------
  minimum time:     88.873 ms (0.00% GC)
  median time:      92.703 ms (0.00% GC)
  mean time:        98.438 ms (0.00% GC)
  maximum time:     224.836 ms (0.00% GC)
  --------------
  samples:          51
  evals/sample:     1

In [30]:
Afloat = convert(Matrix{Float64}, A)
BLAS.set_num_threads(8)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 8 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     18.761 ms (0.00% GC)
  median time:      20.697 ms (0.00% GC)
  mean time:        21.208 ms (0.00% GC)
  maximum time:     27.362 ms (0.00% GC)
  --------------
  samples:          236
  evals/sample:     1

In [31]:
BLAS.set_num_threads(1)
@benchmark LinearAlgebra.mul!($C, $Afloat, $B) # BLAS with 1 threads

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     56.292 ms (0.00% GC)
  median time:      59.270 ms (0.00% GC)
  mean time:        59.264 ms (0.00% GC)
  maximum time:     66.171 ms (0.00% GC)
  --------------
  samples:          85
  evals/sample:     1

## gesp vs @view

In [12]:
n = 5092
p = 1000
q = 1000
x = simulate_random_snparray(undef, n, p)
A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
B = ones(p, q)
C = zeros(n, q);

In [13]:
@benchmark LinearAlgebra.mul!($C, $A, $B) # gesp

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.104 s (0.00% GC)
  median time:      1.113 s (0.00% GC)
  mean time:        1.113 s (0.00% GC)
  maximum time:     1.118 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

In [16]:
@benchmark LinearAlgebra.mul!($C, $A, $B) # @view

BenchmarkTools.Trial: 
  memory estimate:  96 bytes
  allocs estimate:  1
  --------------
  minimum time:     1.112 s (0.00% GC)
  median time:      1.120 s (0.00% GC)
  mean time:        1.120 s (0.00% GC)
  maximum time:     1.126 s (0.00% GC)
  --------------
  samples:          5
  evals/sample:     1

## $Ax$

In [132]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

# any n between 4097 and 4099 doesn't work!
n = 4093
p = 10000
x = simulate_random_snparray(undef, n, p, min_ma=0)

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
b = ones(p)
c = A * b
ctrue = convert(Matrix{Float64}, A) * b
@test all(c .≈ ctrue)

M = 1023, Miter = 0, Mrem = 253, rows_filled = 4093 


[32m[1mTest Passed[22m[39m

## $A^tx$

In [8]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

# any n between 8193 and 8195 doesn't work!
n = 8193
p = 1000
x = simulate_random_snparray(undef, n, p, min_ma=0)

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=false, center=false, scale=false)
b = ones(n)
c = A' * b
ctrue = convert(Matrix{Float64}, A)' * b
@test all(c .≈ ctrue)

[32m[1mTest Passed[22m[39m

## $Ax$ with mean impute

In [23]:
using Revise
using SnpArrays
using LinearAlgebra
using Random
using LoopVectorization
using MendelIHT
using Test

n = 4097
p = 10000
x = simulate_random_snparray(undef, n, p, min_ma=0) # no missing data
x[1, 1025] = 0x01 # missing

A = SnpLinAlg{Float64}(x, model=ADDITIVE_MODEL, impute=true, center=false, scale=false)
Atrue = convert(Matrix{Float64}, x, model=ADDITIVE_MODEL, impute=true, center=false, scale=false)
b = ones(p)
c = A * b
ctrue = Atrue * b
@test all(c .≈ ctrue)

[32m[1mTest Passed[22m[39m

## vashr error

In [11]:
using LinearAlgebra
using Random
using LoopVectorization
using Test

expr = :((Aij == 3) * V[j, c])

function my_gemm(out, s::Matrix{UInt8}, V)
    fill!(out, 0)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    for c in 1:Vcols
        for j in 1:scols
            for i in 1:srows
                ip3 = i + 3
                Aij = (s[ip3 >> 2, j] >> ((ip3 & 0x03) << 1)) & 0x03
                out[i, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
            end
        end
    end
end

function my_gemm_avx(out, s::Matrix{UInt8}, V)
    fill!(out, 0)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    @avx for c in 1:Vcols
        for j in 1:scols
            for i in 1:srows
                ip3 = i + 3
                Aij = (s[ip3 >> 2, j] >> ((ip3 & 0x03) << 1)) & 0x03
                out[i, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
            end
        end
    end
end

function my_gemm_unroll(out, s::Matrix{UInt8}, V)
    fill!(out, 0)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    k = srows >> 2
    rem = srows & 3
    @avx for c in 1:Vcols
        for j in 1:scols
            for l in 1:k
                block = s[l, j]
                for p in 1:4
                    Aij = (block >> (2 * (p - 1))) & 3
                    out[4*(l - 1) + p, c] += $expr
                end
            end
        end
    end
    # TODO handle rem
end

function my_gemm_manual_unroll(out, s::Matrix{UInt8}, V)
    fill!(out, 0)
    Vcols = size(V, 2)
    srows = size(s, 1)
    scols = size(s, 2)
    k = srows >> 2
    rem = srows & 3
    @avx for c in 1:Vcols
        for j in 1:scols
            for l in 1:k
                block = s[l, j]
                # unrolled loop
                p = 1
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 2
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 3
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
                p = 4
                Aij = (block >> (2 * (p - 1))) & 3
                out[4*(l - 1) + p, c] += ((Aij >= 2) + (Aij == 3)) * V[j, c]
            end
        end
    end
    # TODO handle rem
end

my_gemm_manual_unroll (generic function with 1 method)

In [20]:
out_true = zeros(100, 10)
out_test1 = zeros(100, 10)
out_test2 = zeros(100, 10)
out_test3 = zeros(100, 10)
s = rand(UInt8, 25, 100)
V = rand(100, 10)

my_gemm(out_true, s, V)
my_gemm_avx(out_test1, s, V)
@test all(out_true .≈ out_test1)

[32m[1mTest Passed[22m[39m

In [21]:
my_gemm_manual_unroll(out_test2, s, V)
@test all(out_true .≈ out_test2)

[91m[1mTest Failed[22m[39m at [39m[1mIn[21]:3[22m
  Expression: all(out_true .≈ out_test2)


LoadError: [91mThere was an error during testing[39m

In [12]:
my_gemm_unroll(out_test3, s, V)

LoadError: MethodError: no method matching vashr(::VectorizationBase.VecUnroll{1, 4, UInt8, VectorizationBase.Vec{4, UInt8}}, ::VectorizationBase.Vec{4, UInt8})
[0mClosest candidates are:
[0m  vashr(::Any, [91m::Static.StaticInt{M}[39m) where M at /Users/biona001/.julia/packages/VectorizationBase/ax8db/src/static.jl:43
[0m  vashr([91m::Static.StaticInt{M}[39m, ::Any) where M at /Users/biona001/.julia/packages/VectorizationBase/ax8db/src/static.jl:42
[0m  vashr([91m::VectorizationBase.MM{W, X, T1}[39m, ::VectorizationBase.AbstractSIMDVector{W, T2}) where {W, X, T1<:Union{Int16, Int32, Int64, Int8}, T2<:Union{UInt16, UInt32, UInt64, UInt8}} at /Users/biona001/.julia/packages/VectorizationBase/ax8db/src/ranges.jl:207
[0m  ...