-
Notifications
You must be signed in to change notification settings - Fork 20
/
subsample.jl
41 lines (38 loc) · 1.28 KB
/
subsample.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
    subsample(is_in::AbstractVector, is_out::AbstractVector, mask::AbstractVector,
              rowsample::AbstractFloat, rng)

Randomly select a subset of the row ids `1:length(is_in)` and return a view of
`is_out` holding the selected ids in increasing order.

`mask` is refilled with `Random.rand!(rng, mask)`; row `i` is kept when
`mask[i] <= round(UInt8, 255 * rowsample)`. `is_in` is used as scratch space and
is overwritten.

# Arguments
- `is_in`: scratch buffer; its length is the number of rows sampled from.
- `is_out`: output buffer; its first `k` entries receive the `k` selected row ids.
- `mask`: random buffer, at least as long as `is_in`; overwritten on every call.
- `rowsample`: fraction of rows to keep; values for which `round(UInt8, 255 * rowsample)`
  is not representable (outside roughly `[0, 1]`) raise an `InexactError`.
- `rng`: random number generator passed to `Random.rand!`.

# Returns
`view(is_out, 1:k)`, where `k` is the number of selected rows.

# Throws
- `ErrorException` if no row is selected (including when `is_in` is empty).
- `DimensionMismatch` if `mask` is shorter than `is_in`, or if `is_out` is too
  short to hold the selected ids.
"""
function subsample(
    is_in::AbstractVector,
    is_out::AbstractVector,
    mask::AbstractVector,
    rowsample::AbstractFloat,
    rng,
)
    nobs = length(is_in)
    length(mask) >= nobs || throw(DimensionMismatch(
        "mask has length $(length(mask)), expected at least $nobs"))
    Random.rand!(rng, mask)
    # Without this guard, sizing the chunks below would compute cld(0, 0) -> DivideError.
    nobs == 0 && error("no subsample observation - choose larger rowsample")
    cond = round(UInt8, 255 * rowsample)
    # Contiguous chunks of at least ~1024 rows, and no more chunks than threads.
    chunk_size = cld(nobs, min(cld(nobs, 1024), Threads.nthreads()))
    nblocks = cld(nobs, chunk_size)
    counts = zeros(Int, nblocks)
    # Pass 1: each chunk compacts its selected row ids to the front of its own slice
    # of `is_in`. Chunks write disjoint index ranges, so no synchronization is needed.
    Threads.@threads for bid = 1:nblocks
        i_start = chunk_size * (bid - 1) + 1
        i_stop = bid == nblocks ? nobs : i_start + chunk_size - 1
        nsel = 0
        for i = i_start:i_stop
            if mask[i] <= cond
                is_in[i_start+nsel] = i
                nsel += 1
            end
        end
        counts[bid] = nsel
    end
    counts_sum = sum(counts)
    counts_sum == 0 && error("no subsample observation - choose larger rowsample")
    # `is_out` is written under @inbounds below, so its capacity must be verified here.
    length(is_out) >= counts_sum || throw(DimensionMismatch(
        "is_out has length $(length(is_out)), need at least $counts_sum"))
    # Exclusive prefix sum: the offset of each chunk's ids within `is_out`.
    counts_cum = cumsum(counts) .- counts
    # Pass 2: concatenate every chunk's compacted ids into `is_out`, in chunk order.
    Threads.@threads for bid = 1:nblocks
        offset_out = counts_cum[bid]
        offset_in = chunk_size * (bid - 1)
        @inbounds for i = 1:counts[bid]
            is_out[offset_out+i] = is_in[offset_in+i]
        end
    end
    return view(is_out, 1:counts_sum)
end