# Minimizing allocations

It seems that [any allocation is detrimental to multithreading](https://discourse.julialang.org/t/poor-performance-on-cluster-multithreading/12248), since garbage collection is single-threaded (at least as of 2018). We cannot preallocate `M`, but all other intermediate matrices (e.g. `Xwork`, `Hwork`, `N`, temporary vectors) can be preallocated and scaled.
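As a minimal sketch of the preallocation idea (not MendelImpute's actual code), the function below writes two matrix products into caller-supplied buffers with `mul!` from `LinearAlgebra`, rather than creating new matrices on every call. The helper name `fill_M_N!` and the toy dimensions are assumptions for illustration; the shapes of `M` and `N` follow how they are sized in the benchmark cells later in this notebook.

```julia
using LinearAlgebra

# Hypothetical helper: overwrite preallocated M and N in place.
function fill_M_N!(M, N, Hwork, Xwork)
    mul!(M, transpose(Hwork), Hwork)   # M = HᵀH, size d × d
    mul!(N, transpose(Xwork), Hwork)   # N = XᵀH, size n × d
    return nothing
end

p, d, n = 64, 20, 10                   # toy sizes: SNPs, haplotypes, samples
Hwork = rand(Float32, p, d)
Xwork = rand(Float32, p, n)
M = zeros(Float32, d, d)               # allocated once, reused for every call
N = zeros(Float32, n, d)
fill_M_N!(M, N, Hwork, Xwork)
```

The buffers can only be reused as long as their dimensions stay fixed, which is presumably why a matrix whose size changes from window to window is harder to preallocate.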

In [1]:
using Revise
using VCFTools
using MendelImpute
using GeneticVariation
using Random
using SparseArrays
using JLD2, FileIO, JLSO
using ProgressMeter
using GroupSlices
using ThreadPools
using BenchmarkTools
# using Plots
# using ProfileView

In [2]:
Threads.nthreads()

8

# Not yet optimized (7/11/2020)

### window by window intersection with global search

In [5]:
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    dynamic_programming = false);

X_mendel = convert_gt(Float32, outfile)
X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:06
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 7.59873 seconds
    Computing haplotype pair        = 2.23046 seconds
        BLAS3 mul! to get M and N      = 0.0580204 seconds per thread
        haplopair search               = 1.88659 seconds per thread
        finding redundant happairs     = 0.0253855 seconds per thread
    Phasing by win-win intersection = 0.18561 seconds
    Imputation                      = 6.8954 seconds

 17.168989 seconds (77.19 M allocations: 7.610 GiB, 6.38% gc time)
error_rate = 8.32693826254268e-5


### haplotype thinning

In [6]:
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    dynamic_programming = false, thinning_factor=100, max_haplotypes=100);

X_mendel = convert_gt(Float32, outfile)
X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:06
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 7.80156 seconds
    Computing haplotype pair        = 3.30648 seconds
        screening for top haplotypes   = 0.402203 seconds per thread
        BLAS3 mul! to get M and N      = 2.35561 seconds per thread
        haplopair search               = 0.0949087 seconds per thread
        finding redundant happairs     = 0.0348157 seconds per thread
    Phasing by win-win intersection = 0.177686 seconds
    Imputation                      = 6.84377 seconds

 18.479679 seconds (77.72 M allocations: 7.541 GiB, 6.05% gc time)
error_rate = 8.550487693659426e-5


### Lasso

In [10]:
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time hs, ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    lasso = 20, dynamic_programming=false, max_haplotypes=100);

# import imputed result and compare with true
X_mendel = convert_gt(Float32, outfile)
# X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:06
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 7.80041 seconds
    Computing haplotype pair        = 0.683484 seconds
        BLAS3 mul! to get M and N      = 0.104721 seconds per thread
        haplopair search               = 0.307949 seconds per thread
        finding redundant happairs     = 0.0280975 seconds per thread
    Phasing by win-win intersection = 0.179585 seconds
    Imputation                      = 6.49275 seconds

 15.390669 seconds (77.19 M allocations: 7.610 GiB, 7.79% gc time)
error_rate = 8.685062226819259e-5


# Optimized

In [4]:
# global search: Preallocate all vectors/matrices
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    dynamic_programming = false);

X_mendel = convert_gt(Float32, outfile)
X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:07
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 8.25352 seconds
    Computing haplotype pair        = 2.23927 seconds
        BLAS3 mul! to get M and N      = 0.0659687 seconds per thread
        haplopair search               = 1.84551 seconds per thread
        finding redundant happairs     = 0.0299791 seconds per thread
    Phasing by win-win intersection = 0.18616 seconds
    Imputation                      = 6.77303 seconds

 17.844942 seconds (77.19 M allocations: 7.420 GiB, 7.17% gc time)
error_rate = 8.32693826254268e-5


In [3]:
# thinning: preallocate all vectors/matrices
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    dynamic_programming = false, thinning_factor=100, max_haplotypes=100);

X_mendel = convert_gt(Float32, outfile)
X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:06
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 7.63071 seconds
    Computing haplotype pair        = 3.21096 seconds
        screening for top haplotypes   = 0.62019 seconds per thread
        BLAS3 mul! to get M and N      = 2.25471 seconds per thread
        haplopair search               = 0.0934914 seconds per thread
        finding redundant happairs     = 0.0637935 seconds per thread
    Phasing by win-win intersection = 0.194177 seconds
    Imputation                      = 6.52209 seconds

 17.627988 seconds (77.20 M allocations: 7.317 GiB, 5.14% gc time)
error_rate = 8.550487693659426e-5


In [10]:
# lasso: preallocate all vectors/matrices
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time hs, ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    lasso = 20, dynamic_programming=false, max_haplotypes=100);

# import imputed result and compare with true
X_mendel = convert_gt(Float32, outfile)
X_complete = convert_gt(Float32, "./compare2/target.full.vcf.gz")
n, p = size(X_mendel)
println("error_rate = ", sum(X_mendel .!= X_complete) / n / p)
rm(outfile, force=true)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:07
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 8.60205 seconds
    Computing haplotype pair        = 3.22841 seconds
        BLAS3 mul! to get M and N      = 0.0361663 seconds per thread
        haplopair search               = 0.180822 seconds per thread
        finding redundant happairs     = 0.0371089 seconds per thread
    Phasing by win-win intersection = 0.216741 seconds
    Imputation                      = 7.02347 seconds

 19.152174 seconds (77.19 M allocations: 7.420 GiB, 6.80% gc time)
error_rate = 8.685062226819259e-5


# Let's do rigorous benchmarks

In [5]:
# first import all data, declare a bunch of (needed or not) variables, and look at 1 window
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"

loaded = JLSO.load(reffile)
compressed_Hunique = loaded[:compressed_Hunique]
X, X_sampleID, X_chr, X_pos, X_ids, X_ref, X_alt = VCFTools.convert_gt(UInt8, tgtfile, trans=true, save_snp_info=true, msg = "Importing genotype file...");

people = size(X, 2)
tgt_snps = size(X, 1)
ref_snps = length(compressed_Hunique.pos)
tot_windows = floor(Int, tgt_snps / width)
avg_num_unique_haps = round(Int, avg_haplotypes_per_window(compressed_Hunique))
max_windows_per_chunks = nchunks(avg_num_unique_haps, nhaplotypes(compressed_Hunique), width, people, Threads.nthreads(), Base.summarysize(X), compressed_Hunique)
chunks = ceil(Int, tot_windows / min(tot_windows, max_windows_per_chunks))
num_windows_per_chunks = round(Int, tot_windows / chunks)
snps_per_chunk = num_windows_per_chunks * width
last_chunk_windows = tot_windows - (chunks - 1) * num_windows_per_chunks

ph = [HaplotypeMosaicPair(ref_snps) for i in 1:people]
redundant_haplotypes = [OptimalHaplotypeSet(num_windows_per_chunks, nhaplotypes(compressed_Hunique)) for i in 1:people]

chunk = 1
windows = (chunk == chunks ? last_chunk_windows : num_windows_per_chunks)
w_start = (chunk - 1) * num_windows_per_chunks + 1
w_end = (chunk == chunks ? tot_windows : chunk * num_windows_per_chunks)

MendelImpute.initialize!(redundant_haplotypes)
chunk == chunks && MendelImpute.resize!(redundant_haplotypes, last_chunk_windows)

winrange = w_start:w_end
people = size(X, 2)
ref_snps = length(compressed_Hunique.pos)
width = compressed_Hunique.width
windows = length(winrange)
threads = Threads.nthreads()
tothaps = nhaplotypes(compressed_Hunique)

# working arrays
happair1 = [ones(Int32, people)           for _ in 1:threads]
happair2 = [ones(Int32, people)           for _ in 1:threads]
hapscore = [zeros(Float32, size(X, 2))    for _ in 1:threads]
Xwork    = [zeros(Float32, width, people) for _ in 1:threads]
redunhaps_bitvec1 = [falses(tothaps) for _ in 1:threads]
redunhaps_bitvec2 = [falses(tothaps) for _ in 1:threads]

# window 1
absolute_w = 1
Hw_aligned = compressed_Hunique.CW_typed[absolute_w].uniqueH
Xw_idx_start = (absolute_w - 1) * width + 1
Xw_idx_end = absolute_w * width
Xw_aligned = view(X, Xw_idx_start:Xw_idx_end, :)
id = Threads.threadid();

Importing genotype file...100%|█████████████████████████| Time: 0:00:06

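The working arrays above are allocated once per thread. As a hedged sketch of one way such buffers can be used safely (this is not how MendelImpute schedules its windows), the toy kernel below gives each loop iteration exclusive ownership of one buffer and lets it process every `nbuf`-th window, so no two tasks ever write to the same scratch matrix. The function `window_sums!` and the toy sizes are hypothetical.

```julia
# Hypothetical toy kernel: per-window genotype sums, one Float32 work buffer per task.
function window_sums!(out::Vector{Float32}, X::AbstractMatrix{UInt8}, width::Int,
                      bufs::Vector{Matrix{Float32}})
    nbuf = length(bufs)
    Threads.@threads for b in 1:nbuf
        buf = bufs[b]                     # iteration b is the only user of bufs[b]
        for w in b:nbuf:length(out)       # windows b, b + nbuf, b + 2nbuf, ...
            rows = ((w - 1) * width + 1):(w * width)
            buf .= view(X, rows, :)       # in-place UInt8 -> Float32 conversion
            out[w] = sum(buf)
        end
    end
    return out
end

width, windows, people = 4, 6, 5
X    = rand(0x00:0x02, width * windows, people)
bufs = [zeros(Float32, width, people) for _ in 1:Threads.nthreads()]
out  = zeros(Float32, windows)
window_sums!(out, X, width, bufs)
```

Indexing buffers by the loop variable instead of `Threads.threadid()` is a design choice in this sketch: it does not depend on which thread a task happens to run on.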

In [3]:
# global search must allocate M, Hwork, N for each window, but there's no other allocation
M = zeros(Float32, size(Hw_aligned, 2), size(Hw_aligned, 2))
N = zeros(Float32, size(Xw_aligned, 2), size(Hw_aligned, 2))
Hwork = convert(Matrix{Float32}, Hw_aligned)
@benchmark haplopair!($Xw_aligned, $Hw_aligned, happair1=$(happair1[id]), 
    happair2=$(happair2[id]), hapscore=$(hapscore[id]), Xwork=$(Xwork[id]),
    M=$M, N=$N, Hwork=$Hwork)

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     121.945 ms (0.00% GC)
  median time:      124.566 ms (0.00% GC)
  mean time:        124.980 ms (0.00% GC)
  maximum time:     131.547 ms (0.00% GC)
  --------------
  samples:          41
  evals/sample:     1

In [5]:
# thinning must allocate a bunch of vectors/matrices. Only R and Hwork cannot be preallocated
# all allocations are due to distance computation.
keep    = 100
maxindx = zeros(Int, keep)
maxgrad = zeros(Float32, keep)
Xi      = zeros(Float32, size(Xw_aligned, 1))
N       = zeros(Float32, keep)
Hk    = zeros(Float32, size(Hw_aligned, 1), keep)
M     = zeros(Float32, keep, keep)
Xwork = zeros(Float32, size(Xw_aligned, 1), size(Xw_aligned, 2))
Hwork = convert(Matrix{Float32}, Hw_aligned)
R     = rand(Float32, size(Hw_aligned, 2), size(Xw_aligned, 2))
@benchmark haplopair_thin_BLAS2!($Xw_aligned, $Hw_aligned, allele_freq=nothing, 
    keep=$keep, happair1=$(happair1[id]), happair2=$(happair2[id]), 
    hapscore=$(hapscore[id]), maxindx=$maxindx, maxgrad=$maxgrad, Xi=$Xi, N=$N, 
    Hk=$Hk, M=$M, Xwork=$Xwork, Hwork=$Hwork, R=$R)

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     211.456 ms (0.00% GC)
  median time:      225.450 ms (0.00% GC)
  mean time:        225.303 ms (0.00% GC)
  maximum time:     244.914 ms (0.00% GC)
  --------------
  samples:          23
  evals/sample:     1

In [4]:
# lasso: here Hwork, M, Nt cannot be preallocated, but there's no other allocation.
r       = 10
maxindx = zeros(Int, r)
maxgrad = zeros(Float32, r)
Xwork = zeros(Float32, size(Xw_aligned, 1), size(Xw_aligned, 2))
Hwork = convert(Matrix{Float32}, Hw_aligned)
M     = zeros(Float32, size(Hw_aligned, 2), size(Hw_aligned, 2))
Nt    = zeros(Float32, size(Hw_aligned, 2), size(Xw_aligned, 2))
@benchmark haplopair_lasso!($Xw_aligned, $Hw_aligned, r=$r, happair1=$(happair1[id]), 
    happair2=$(happair2[id]), hapscore=$(hapscore[id]), maxindx=$maxindx, 
    maxgrad=$maxgrad, Xwork=$Xwork, Hwork=$Hwork, M=$M, Nt=$Nt)

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     26.048 ms (0.00% GC)
  median time:      27.389 ms (0.00% GC)
  mean time:        27.534 ms (0.00% GC)
  maximum time:     32.354 ms (0.00% GC)
  --------------
  samples:          182
  evals/sample:     1

In [12]:
# computing redundant haps allocates in the first chunk
w = something(findfirst(x -> x == absolute_w, winrange)) # window index of current chunk
redundant_haplotypes = [OptimalHaplotypeSet(num_windows_per_chunks, nhaplotypes(compressed_Hunique)) for i in 1:people]
@time compute_redundant_haplotypes!(redundant_haplotypes, compressed_Hunique, 
    (happair1[id]), (happair2[id]), w, absolute_w, (redunhaps_bitvec1[id]), 
    (redunhaps_bitvec2[id]))

  0.005366 seconds (4.00 k allocations: 9.461 MiB)


In [13]:
# computing redundant haps has 0 allocations in subsequent chunks
w = something(findfirst(x -> x == absolute_w, winrange)) # window index of current chunk
@benchmark compute_redundant_haplotypes!($redundant_haplotypes, $compressed_Hunique, 
    $(happair1[id]), $(happair2[id]), $w, $absolute_w, $(redunhaps_bitvec1[id]), 
    $(redunhaps_bitvec2[id]))

BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     593.756 μs (0.00% GC)
  median time:      607.433 μs (0.00% GC)
  mean time:        617.867 μs (0.00% GC)
  maximum time:     1.646 ms (0.00% GC)
  --------------
  samples:          8068
  evals/sample:     1

# Writing routine

In [10]:
# create variables
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
outfile = "./compare2/mendel.imputed.vcf.gz"
@time ph = phase(tgtfile, reffile, outfile = outfile, width = width,
    dynamic_programming = false);

loaded = JLSO.load(reffile)
compressed_Hunique = loaded[:compressed_Hunique]

X, X_sampleID, X_chr, X_pos, X_ids, X_ref, X_alt = VCFTools.convert_gt(UInt8, tgtfile, trans=true, save_snp_info=true, msg = "Importing genotype file...")
XtoH_idx = indexin(X_pos, compressed_Hunique.pos) # X_pos[i] == H_pos[XtoH_idx[i]]
X_full = Matrix{Union{Missing, UInt8}}(missing, length(compressed_Hunique.pos), size(X, 2))
copyto!(@view(X_full[XtoH_idx, :]), X); # keep known entries
impute_discard_phase!(X_full, compressed_Hunique, ph)

Importing reference haplotype data...


Importing genotype file...100%|█████████████████████████| Time: 0:00:07
Writing to file...100%|█████████████████████████████████| Time: 0:00:06


Total windows = 72, averaging ~ 627 unique haplotypes per window.

Timings: 
    Data import                     = 8.4725 seconds
    Computing haplotype pair        = 3.33952 seconds
        BLAS3 mul! to get M and N      = 0.0983789 seconds per thread
        haplopair search               = 2.90371 seconds per thread
        finding redundant happairs     = 0.0345083 seconds per thread
    Phasing by win-win intersection = 0.287375 seconds
    Imputation                      = 7.33042 seconds

 19.804090 seconds (78.25 M allocations: 7.471 GiB, 5.16% gc time)


Importing genotype file...100%|█████████████████████████| Time: 0:00:07


In [None]:
@time write(outfile, X_full, compressed_Hunique, X_sampleID)

In [18]:
# original
@btime write($outfile, $X_full, $compressed_Hunique, $X_sampleID) seconds=30

Writing to file...100%|█████████████████████████████████| Time: 0:00:06

  5.933 s (414103 allocations: 362.05 MiB)


In [16]:
# copy
@btime write($outfile, $X_full, $compressed_Hunique, $X_sampleID) seconds=30

Writing to file...100%|█████████████████████████████████| Time: 0:00:06

  6.211 s (414104 allocations: 362.05 MiB)


# Optimize window by window intersection

In [None]:
using Revise
using VCFTools
using MendelImpute
using GeneticVariation
using Random
using SparseArrays
using JLD2, FileIO, JLSO
using ProgressMeter
using GroupSlices
using ThreadPools
using BenchmarkTools
# using Plots
# using ProfileView

In [10]:
# first import all data, declare a bunch of (needed or not) variables, and look at 1 window
cd("/Users/biona001/.julia/dev/MendelImpute/simulation")
Random.seed!(2020)
width   = 512
tgtfile = "./compare2/target.typedOnly.maf0.01.masked.vcf.gz"
reffile = "./compare2/ref.excludeTarget.w$width.jlso"
loaded = JLSO.load(reffile)
compressed_Hunique = loaded[:compressed_Hunique]
X, X_sampleID, X_chr, X_pos, X_ids, X_ref, X_alt = VCFTools.convert_gt(UInt8, tgtfile, 
    trans=true, save_snp_info=true, msg = "Importing genotype file...");

Importing genotype file...100%|█████████████████████████| Time: 0:00:06


In [None]:
"""
    phase_sample!(strand1, strand2, happair1, happair2, compressed_Hunique)

For a sample, intersects redundant haplotype pairs window by window, 
then searches for breakpoints.

# Arguments
- `strand1`: Possible strand 1 haplotypes. `strand1[w]` stores possible haplotype indices in window `w`. 
- `strand2`: Possible strand 2 haplotypes. `strand2[w]` stores possible haplotype indices in window `w`. 
- `happair1`: `happair1[w]` stores best haplotype index for strand1 in window `w`
- `happair2`: `happair2[w]` stores best haplotype index for strand2 in window `w`
- `compressed_Hunique`: A `CompressedHaplotypes` object

# Optional storage argument
- `storage1`: A ∩ C
- `storage2`: B ∩ D
- `storage3`: A ∩ D
- `storage4`: B ∩ C

These are needed because intersection can happen in 4 ways:
A   B      A   B
|   |  or    X
C   D      C   D
"""
function phase_sample!(
    strand1::Vector{BitVector},
    strand2::Vector{BitVector},
    happair1::AbstractVector{<:Integer},
    happair2::AbstractVector{<:Integer},
    compressed_Hunique::CompressedHaplotypes,
    # working arrays
    storage1::BitVector = falses(nhaplotypes(compressed_Hunique)),
    storage2::BitVector = falses(nhaplotypes(compressed_Hunique)),
    storage3::BitVector = falses(nhaplotypes(compressed_Hunique)),
    storage4::BitVector = falses(nhaplotypes(compressed_Hunique)),
    )
    
    windows = length(strand1)
    windows == length(strand2) == length(happair1) == length(happair2) ||
         error("strand1, strand2, happair1, and happair2 have different lengths.")
    lifespan = (1, 1) # counter to track survival time
    
    # get first window's optimal haplotypes
    h1 = happair1[1]
    h2 = happair2[1]
    survivors1 = get(compressed_Hunique.CW_typed[1].hapmap, h1, h1)
    survivors2 = get(compressed_Hunique.CW_typed[1].hapmap, h2, h2)
    store!(storage1, survivors1)
    store!(storage2, survivors2)
    store!(storage3, survivors1)
    store!(storage4, survivors2)
    
    for w in 2:windows
        # convert integer indices to a bitarray
        h1 = happair1[w]
        h2 = happair2[w]
        h1set = get(compressed_Hunique.CW_typed[w].hapmap, h1, h1)
        h2set = get(compressed_Hunique.CW_typed[w].hapmap, h2, h2)
        store!(strand1[w], h1set)
        store!(strand2[w], h2set)
        
        # heuristic to decide whether cross-over is better
        crossed = false
        storage1 .= survivors1 .& strand1[w] # A ∩ C
        storage2 .= survivors2 .& strand2[w] # B ∩ D
        storage3 .= survivors1 .& strand2[w] # A ∩ D
        storage4 .= survivors2 .& strand1[w] # B ∩ C
        if sum(storage1) + sum(storage2) < sum(storage3) + sum(storage4)
            strand1[w], strand2[w] = strand2[w], strand1[w]
            crossed = true
        end
        
        # update strand 1 and 2
        if crossed
            update_strand!(strand1, storage3, survivors1, w, lifespan[1])
            update_strand!(strand2, storage4, survivors2, w, lifespan[2])
        else
            update_strand!(strand1, storage1, survivors1, w, lifespan[1])
            update_strand!(strand2, storage2, survivors2, w, lifespan[2])
        end
        
    end
    
end
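The crossover heuristic in `phase_sample!` can be illustrated on a toy example. The sketch below is only an illustration under assumed inputs: the helper `bits`, the number of haplotypes, and the four candidate sets are made up. It computes the four intersections into preallocated `BitVector`s and compares the "straight" pairing (A∩C, B∩D) with the "crossed" pairing (A∩D, B∩C).

```julia
nhaps = 8
# Hypothetical helper: BitVector of length nhaps with the given indices set.
bits(idx) = (b = falses(nhaps); b[idx] .= true; b)

A = bits([1, 2, 3])      # survivors on strand 1 so far
B = bits([5, 6])         # survivors on strand 2 so far
C = bits([5, 7])         # strand 1 candidates in the next window
D = bits([2, 3, 8])      # strand 2 candidates in the next window

# Preallocated storage, overwritten in place by the fused broadcasts below.
AC = falses(nhaps); BD = falses(nhaps)
AD = falses(nhaps); BC = falses(nhaps)
AC .= A .& C
BD .= B .& D
AD .= A .& D
BC .= B .& C

straight_score = count(AC) + count(BD)   # survivors if strands stay as they are
crossed_score  = count(AD) + count(BC)   # survivors if strands are swapped
crossed = straight_score < crossed_score
```

In this toy case the straight pairing leaves no survivors while the crossed pairing leaves three, so the heuristic would swap the two strands in the new window.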

In [None]:
"""
    store!(strand, v)

Stores the integers in `v` in `strand` in BitVector form, i.e. `strand[i] = true` 
for all `i ∈ v` and `false` elsewhere. 
"""
function store!(strand::BitVector, v::AbstractVector)
    fill!(strand, false)
    for i in v
        @inbounds strand[i] = true
    end
    return nothing
end
function store!(strand::BitVector, v::Integer)
    fill!(strand, false)
    strand[v] = true
    return nothing
end
"""
    update_strand!(strand, survivors, window, lifespan)

Updates `strand` and `survivors` after intersection. If there are no survivors, 
fills `survivors` with haplotypes from the nearest unchecked window and prunes the 
previous `lifespan` windows. 

# Arguments
- `strand`: possible haplotypes for each window.
- `survivors`: Current surviving haplotypes after intersection
- `survivors_prev`: Surviving haplotypes before intersection
- `window`: current window
- `lifespan`: Number of windows haplotypes in `survivors_prev` survived 
"""
function update_strand!(
    strand::Vector{BitVector}, 
    survivors::BitVector, 
    survivors_prev::BitVector, 
    window::Int, 
    lifespan::Int,
    )
    if sum(survivors) == 0
        # delete all nonmatching haplotypes in previous windows
        for w in (window - lifespan):(window - 1)
            @inbounds strand[w] .= survivors_prev
        end

        # fills from current window
        survivors .= strand[window]
        lifespan = 1
    else
#         survivors .= strand[window]
        lifespan += 1
    end
    return lifespan # integers are passed by value, so hand the updated counter back
end
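To make the survivor/lifespan bookkeeping concrete, here is a self-contained sketch of the same idea for a single strand. It is one interpretation of the docstring above, not the package's implementation: the function `greedy_segments!` and the helper `bits` are hypothetical. It intersects candidate sets window by window, and when the intersection becomes empty it prunes the preceding windows down to the last non-empty survivor set and restarts from the current window.

```julia
# Hypothetical sketch: greedy window-by-window intersection for one strand.
function greedy_segments!(strand::Vector{BitVector})
    survivors = copy(strand[1])
    scratch   = similar(survivors)          # reused for every intersection
    segments  = UnitRange{Int}[]
    start     = 1
    for w in 2:length(strand)
        scratch .= survivors .& strand[w]
        if any(scratch)
            survivors .= scratch            # some haplotypes survive into window w
        else
            for v in start:(w - 1)          # prune the windows the survivors spanned
                strand[v] .= survivors
            end
            push!(segments, start:(w - 1))
            survivors .= strand[w]          # restart from the current window
            start = w
        end
    end
    for v in start:length(strand)           # prune the final segment
        strand[v] .= survivors
    end
    push!(segments, start:length(strand))
    return segments
end

nhaps = 6
bits(idx) = (b = falses(nhaps); b[idx] .= true; b)
strand = [bits([1, 2, 3]), bits([2, 3]), bits([3, 4]), bits([5]), bits([5, 6])]
segments = greedy_segments!(strand)
```

Each returned range is a run of windows that share at least one haplotype, and the boundaries between ranges are the candidate breakpoints.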

In [None]:
compressed_Hunique.CW_typed[1].hapmap

Random.seed!(2020)
happair1 = rand(1:100, 100);
happair2 = rand(1:100, 100);
storage1 = Int[]
storage2 = Int[]
storage3 = Int[]
storage4 = Int[]


# using array of int

It seems that `intersect!` in Base allocates a lot, and its implementation is hard to follow.

In [4]:
@btime intersect!(x, y) setup=(x = [1, 2, 3]; y = [1, 4])

  311.573 ns (15 allocations: 1.05 KiB)


1-element Array{Int64,1}:
 1

In [11]:
@which intersect!([1, 2, 3], [1, 4])

## Try writing our own non-allocating intersect

In [13]:
seen = Set{Int}()
# sizehint!(seen, 1000000)

Set(Int64[])

In [14]:
@btime push!($seen, x) setup=(x = rand(1:1000000000))

  5.482 ns (0 allocations: 0 bytes)


Set([568419203, 149786033, 447200868, 213527169, 96429203, 465402303, 925046694, 709520920, 415463448, 971733567  …  738514259, 394646305, 373677655, 115703177, 846435711, 402395238, 86583111, 408774165, 665285694, 722334058])

In [35]:
function Base.intersect!(v::AbstractVector, u::AbstractVector, seen::Set)
    empty!(seen)
    for i in u
        push!(seen, i)
    end
    for i in Iterators.reverse(1:length(v))
        @inbounds v[i] ∉ seen && deleteat!(v, i)
    end
    nothing
end

In [33]:
@btime intersect!(x, y) setup=(x = rand(1:10000, 1000); y = rand(1:10000, 1000));

  70.559 μs (35 allocations: 95.62 KiB)


In [36]:
@btime intersect!(x, y, $seen) setup=(x = rand(1:10000, 1000); y = rand(1:10000, 1000));

  11.087 μs (0 allocations: 0 bytes)
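A possible variation on the function above, offered only as a sketch: each `deleteat!` call has to shift the elements after the deleted position, so deleting one element at a time can become costly for long vectors. `filter!` compacts the vector in a single pass instead. The name `intersect_inplace!` is hypothetical and is used here to avoid adding another method to `Base.intersect!`. As with the version above, duplicates in `v` are kept if they appear in `u`.

```julia
# Hypothetical single-pass variant; `seen` is a caller-supplied, reusable Set.
function intersect_inplace!(v::AbstractVector, u::AbstractVector, seen::Set)
    empty!(seen)
    union!(seen, u)          # record every element of u
    filter!(in(seen), v)     # keep only the elements of v that are in u
    return v
end

seen = Set{Int}()
x = [1, 2, 3, 2]
y = [2, 4, 1]
intersect_inplace!(x, y, seen)
```

After the call, `x` holds `[1, 2, 2]` in this example, and the same `seen` set can be passed to later calls.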
