Adds a cache flag to delayed - result will be cached until a memory limit
is reached on each worker. The default memory limit is 50% of the node's
total memory divided by the number of workers.
Shashi Gowda committed May 2, 2017
1 parent d27a967 commit dee4971
Showing 6 changed files with 176 additions and 32 deletions.
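For context, a minimal sketch of how the new flag is meant to be used, based on the API exercised by test/cache.jl below; the worker count and array sizes here are made up, and Julia 0.5-era syntax (Nullable, gather) is assumed:

    addprocs(2)
    using Dagger

    a = delayed(_ -> rand(10^6); cache=true)(1)  # ask the worker to cache the result
    s = delayed(sum)(a)
    gather(s)  # computes rand(10^6), then soft-releases it into the worker's cache
    gather(s)  # reuses the cached array unless the worker has evicted it

    # The memory budget is per worker; to change it, set it remotely as the test does:
    # remotecall_fetch(() -> Dagger.MAX_MEMORY[] = 10^8, pid)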
35 changes: 25 additions & 10 deletions src/chunks.jl
@@ -39,6 +39,14 @@ chunktype(c::Chunk) = c.chunktype
 persist!(t::Chunk) = (t.persist=true; t)
 shouldpersist(p::Chunk) = p.persist
 affinity(c::Chunk) = affinity(c.handle)
+function unrelease{T}(c::Chunk{T,MemToken})
+    if unrelease_token(c.handle)
+        Nullable{Any}(c)
+    else
+        Nullable{Any}()
+    end
+end
+unrelease(c::Chunk) = c
 
 function gather(ctx, chunk::Chunk)
     # delegate fetching to handle by default.
@@ -47,15 +55,22 @@ end
 
 
 ### ChunkIO
-gather(ctx, ref::MemToken) = fetch(ref)
-affinity(c::MemToken) = [OSProc(c.where)]
+function gather(ctx, ref::MemToken)
+    res = fetch(ref)
+    if isnull(res)
+        throw(KeyError(ref))
+    else
+        get(res)
+    end
+end
+affinity(c::MemToken) = [OSProc(c.where)=>c.size]
 
 """
 Create a chunk from a sequential object.
 """
 function tochunk(x; persist=false)
     ref = make_token(x)
-    Chunk(typeof(x), domain(x), ref, true)
+    Chunk(typeof(x), domain(x), ref, persist)
 end
 tochunk(x::AbstractChunk) = x
 
@@ -88,7 +103,7 @@ function gather{X}(ctx, s::View{Chunk{X, MemToken}})
     ref = s.chunk.handle
     pid = ref.where
     let d = s.subdomain
-        remotecall_fetch(x -> fetch(x)[d], pid, ref)
+        remotecall_fetch(x -> get(fetch(x))[d], pid, ref)
     end
 end
 
@@ -203,20 +218,20 @@ function lookup_parts{N}(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDo
     pieces, out_dmn
 end
 
-function free!(x::Cat, force=true)
+function free!(x::Cat; force=true, cache=false)
     for p in chunks(x)
-        free!(p, force)
+        free!(p, force=force, cache=cache)
     end
 end
 # Check to see if the node is set to persist;
 # if it is, force can override it
-function free!{X}(s::Chunk{X, MemToken}, force=true)
+function free!{X}(s::Chunk{X, MemToken}; force=true, cache=false)
     if force || !s.persist
-        release_token(s.handle)
+        release_token(s.handle, cache)
     end
 end
-free!(s::AbstractChunk, force=true) = nothing
-free!(s::View, force=true) = nothing
+free!(s::View; force=true, cache=false) = nothing
+free!(x; force=true, cache=false) = x # catch-all for non-chunks
 
 
 Base.@deprecate_binding AbstractPart AbstractChunk
40 changes: 32 additions & 8 deletions src/compute.jl
@@ -100,9 +100,9 @@ function stage(ctx, x::Cached)
     x
 end
 
-free!(x::Computed, force=true) = free!(x.result, force)
+free!(x::Computed; force=true, cache=false) = free!(x.result, force=force, cache=cache)
 function finalize_computed!(x::Computed)
-    @schedule free!(x, true) # @schedule needed because gc can't yield
+    @schedule free!(x; force=true) # @schedule needed because gc can't yield
 end
 
 gather(ctx, x::Computed) = gather(ctx, x.result)
@@ -151,6 +151,9 @@ thunkize(ctx, x::Thunk) = x
 function finish_task!(state, node, node_order; free=true)
     deps = sort([i for i in state[:dependents][node]], by=node_order)
     immediate_next = false
+    if istask(node) && node.cache
+        node.cache_ref = Nullable{Any}(state[:cache][node])
+    end
     for dep in deps
         set = state[:waiting][dep]
         pop!(set, node)
@@ -171,7 +174,7 @@ function finish_task!(state, node, node_order; free=true)
         if free && isempty(s)
             if haskey(state[:cache], inp)
                 _node = state[:cache][inp]
-                free!(_node, false)
+                free!(_node, force=false, cache=(istask(inp) && inp.cache))
                 pop!(state[:cache], inp)
             end
         end
@@ -182,8 +185,6 @@ function finish_task!(state, node, node_order; free=true)
     immediate_next
 end
 
-free!(x, force=true) = x # catch-all for non-chunks
-
 ###### Scheduler #######
 """
 Compute a Thunk - creates the DAG, assigns ranks to
@@ -256,8 +257,31 @@ function pop_with_affinity!(tasks, proc)
 end
 
 function fire_task!(ctx, thunk, proc, state, chan, node_order)
-    @logmsg("W$(proc.pid) + $thunk ($(thunk.f)) input:$(thunk.inputs)")
+    @logmsg("W$(proc.pid) + $thunk ($(thunk.f)) input:$(thunk.inputs) cache:$(thunk.cache) $(thunk.cache_ref)")
     push!(state[:running], thunk)
+    if thunk.cache && !isnull(thunk.cache_ref)
+        # the result might already be cached
+        data = unrelease(get(thunk.cache_ref)) # ask the worker to keep the data
+                                               # around till this compute cycle frees it
+        if !isnull(data)
+            @logmsg("cache hit: $(get(thunk.cache_ref))")
+            state[:cache][thunk] = get(data)
+            immediate_next = finish_task!(state, thunk, node_order; free=false)
+            if !isempty(state[:ready])
+                if immediate_next
+                    thunk = pop!(state[:ready])
+                else
+                    thunk = pop_with_affinity!(state[:ready], proc)
+                end
+                fire_task!(ctx, thunk, proc, state, chan, node_order)
+            end
+            return
+        else
+            thunk.cache_ref = Nullable{Any}()
+            @logmsg("cache miss: $(thunk.cache_ref)")
+        end
+    end
+
     if thunk.meta
         # Run it on the parent node
         # do not _move data.
@@ -333,8 +357,8 @@ Given a root node of the DAG, calculates a total order for tie-breaking
 i.e. total number of tasks depending on the result of the said node.
 Args:
-    - ndeps: result of `noffspring`
     - node: root node
+    - ndeps: result of `noffspring`
 """
 function order(node::Thunk, ndeps)
     order([node], ndeps, 0)[2]
@@ -383,7 +407,7 @@ _move(ctx, to_proc::OSProc, x::AbstractChunk) = gather(ctx, x)
 
 function do_task(ctx, proc, thunk_id, f, data, send_result, persist)
     @dbg timespan_start(ctx, :comm, thunk_id, proc)
-    fetched = map(x->_move(ctx, proc, x), data)
+    time_cost = @elapsed fetched = map(x->_move(ctx, proc, x), data)
     @dbg timespan_end(ctx, :comm, thunk_id, proc)
 
     @dbg timespan_start(ctx, :compute, thunk_id, proc)
79 changes: 71 additions & 8 deletions src/lib/dumbref.jl
@@ -2,34 +2,97 @@
 immutable MemToken
     where::Int
     key::Int
+    size::Int # size in bytes
+    released::Bool # has this been released but asked to be cached
 end
 
-global _mymem = Dict{MemToken,Any}()
+const MAX_MEMORY = Ref{Float64}((Sys.total_memory() / nprocs()) / 2) # half the process's share
+const _mymem = Dict{Int,Any}()
+const _token_order = MemToken[]
 
 let token_count = 0
     global next_token_id
     next_token_id() = (token_count+=1)
 end
 
+function data_size(d)
+    Base.summarysize(d)
+end
+
+function data_size(xs::AbstractArray{String})
+    # doesn't check for redundant references, but
+    # much faster than summarysize
+    sum(map(sizeof, xs))
+end
+
 function make_token(data)
-    tok = MemToken(myid(), next_token_id())
-    _mymem[tok] = data
+    sz = data_size(data)
+    tok = MemToken(myid(), next_token_id(), sz, false)
+    total_size = sum(map(x->x.size, _token_order)) + sz
+
+    i = 1
+    while total_size > MAX_MEMORY[] && i <= length(_token_order)
+        # we need to weed out some old data here;
+        # evict tokens that have already been released
+        t = _token_order[i]
+        if t.released
+            filter!(x->x.key != t.key, _token_order)
+            x = pop!(_mymem, t.key)
+            total_size -= t.size
+            @logmsg("cached & released $t - $(t.size)B dropped")
+        end
+        i += 1
+    end
+    push!(_token_order, tok)
+    _mymem[tok.key] = data
     tok
 end
 
-function release_token(tok)
+function release_token(tok, keeparound=false)
     if tok.where == myid()
-        x = pop!(_mymem, tok)
-        @logmsg("removed $tok - $(sizeof(x))B freed")
+        if keeparound
+            # set released to true, but don't remove the data yet.
+            tok_released = MemToken(tok.where, tok.key, tok.size, true)
+            idx = find(x->x.key == tok.key, _token_order)
+            _token_order[idx] = tok_released
+            @logmsg("soft-released $tok - $(tok.size)B freed")
+        else
+            filter!(x->x.key != tok.key, _token_order)
+            x = pop!(_mymem, tok.key)
+            @logmsg("removed $tok - $(tok.size)B freed")
+        end
     else
-        remotecall_fetch(()->release_token(tok), tok.where)
+        remotecall_fetch(()->release_token(tok, keeparound), tok.where)
     end
     nothing
 end
 
 function Base.fetch(t::MemToken)
     if t.where == myid()
-        _mymem[t]
+        if haskey(_mymem, t.key)
+            return Nullable{Any}(_mymem[t.key])
+        else
+            return Nullable{Any}()
+        end
     else
         remotecall_fetch(()->fetch(t), t.where)
     end
 end
+
+function unrelease_token(tok)
+    # first check whether the token was removed due to cache pruning
+    if tok.where == myid()
+        if !haskey(_mymem, tok.key)
+            return false
+        end
+        # set released back to false so the data is kept around
+        tok_unreleased = MemToken(tok.where, tok.key, tok.size, false)
+        idx = find(x->x.key == tok.key, _token_order)
+        _token_order[idx] = tok_unreleased
+        true
+    else
+        remotecall_fetch(()->unrelease_token(tok), tok.where)
+    end
+end
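Taken together, these functions implement a soft-release protocol. A minimal single-process sketch of the lifecycle, using only the names defined above (run on one worker):

    x = rand(10^4)
    tok = Dagger.make_token(x)        # stores x in _mymem, appends tok to _token_order
    Dagger.release_token(tok, true)   # soft release: marked released, data kept around
    Dagger.unrelease_token(tok)       # returns true: the data survived, token revived
    get(fetch(tok)) === x             # fetch now returns a Nullable wrapping the data
    Dagger.release_token(tok)         # hard release: the data is dropped immediately
    isnull(fetch(tok))                # returns true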

25 changes: 19 additions & 6 deletions src/thunk.jl
@@ -14,33 +14,46 @@ type Thunk <: AbstractChunk
     id::Int
     get_result::Bool # whether the worker should send the result or only the metadata
     meta::Bool
-    persist::Bool
+    persist::Bool # don't `free!` result after computing
+    cache::Bool   # release the result, giving the worker an opportunity to
+                  # cache it
+    cache_ref::Nullable
     function Thunk(f, xs...;
                    id::Int=next_id(),
                    get_result::Bool=false,
                    meta::Bool=false,
-                   persist::Bool=false)
-        thunk = new(f,xs,id,get_result,meta,persist)
+                   persist::Bool=false,
+                   cache::Bool=false,
+                   cache_ref::Nullable{Any}=Nullable{Any}(),
+                  )
+        thunk = new(f,xs,id,get_result,meta,persist, cache, cache_ref)
         _thunk_dict[id] = thunk
         thunk
     end
 end
 
 function affinity(t::Thunk)
-    aff = []
+    if t.cache && !isnull(t.cache_ref)
+        return affinity(get(t.cache_ref))
+    end
+    aff = Dict{Processor,Int}()
     for inp in inputs(t)
         if isa(inp, AbstractChunk)
-            aff = vcat(aff, affinity(inp))
+            for a in affinity(inp)
+                proc, sz = a
+                aff[proc] = get(aff, proc, 0) + sz
+            end
         end
     end
-    aff
+    sort!(collect(aff), by=last, rev=true)
 end
 
 function delayed(f; kwargs...)
     (args...) -> Thunk(f, args...; kwargs...)
 end
 
 persist!(t::Thunk) = (t.persist=true; t)
+cache_result!(t::Thunk) = (t.cache=true; t)
 
 # @generated function compose{N}(f, g, t::NTuple{N})
 #     if N <= 4
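The reworked affinity(t::Thunk) above aggregates input sizes per processor and prefers the processors holding the most bytes. A toy illustration of that aggregation, with Ints standing in for OSProc values (the numbers are made up):

    aff = Dict{Int,Int}()
    for (proc, sz) in [(2, 10^4), (3, 10^6), (2, 5*10^4)]
        aff[proc] = get(aff, proc, 0) + sz
    end
    sort!(collect(aff), by=last, rev=true)  # => [3=>1000000, 2=>60000]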
28 changes: 28 additions & 0 deletions test/cache.jl
@@ -0,0 +1,28 @@
+using Base.Test
+using Dagger
+@testset "cache" begin
+    # allow 8MB of cache beyond current usage on each worker
+    test_extra = 8*10^6
+    map(workers()) do pid
+        pid=>remotecall_fetch(pid) do
+            totsz = sum(map(x->x.size, Dagger._token_order))
+            Dagger.MAX_MEMORY[] = totsz + test_extra
+        end
+    end
+
+    thunks1 = map(delayed(_ -> (println(myid()); rand(10^5)), cache=true), workers())
+    sum1 = delayed((x...)->sum([x...]))(map(delayed(sum), thunks1)...)
+    thunks2 = map(delayed(-), thunks1)
+    sum2 = delayed((x...)->sum([x...]))(map(delayed(sum), thunks2)...)
+    s1 = gather(sum1)
+    @test -s1 == gather(sum2)
+    @test s1 == gather(sum1)
+    @test -gather(sum1) == gather(sum2)
+
+    thunks1 = map(delayed(_ -> rand(10^6), cache=true), workers())
+    sum1 = delayed((x...)->sum([x...]))(map(delayed(sum), thunks1)...)
+    thunks2 = map(delayed(-), thunks1)
+    sum2 = delayed((x...)->sum([x...]))(map(delayed(sum), thunks2)...)
+    s1 = gather(sum1) # this should evict thunks1 from memory
+    @test -s1 != gather(sum2)
+end
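Assuming a checkout of the repo with Dagger installed, the new test can also be run on its own with a couple of workers:

    julia -p 2 test/cache.jl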
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -5,3 +5,4 @@ using Dagger
 
 include("domain.jl")
 include("array.jl")
+include("cache.jl")