Skip to content
This repository has been archived by the owner on Nov 8, 2022. It is now read-only.

Commit

Permalink
net.jl refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
leonardt committed Apr 6, 2016
1 parent 3ad40a9 commit 443940d
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 93 deletions.
12 changes: 12 additions & 0 deletions src/ensembles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ function init_inputs(ensemble::AbstractEnsemble, net::Net)
end
end

function init_params(ensemble::AbstractEnsemble, net::Net)
if :params in fieldnames(ensemble)
for param in ensemble.params
param.value = get_buffer(net, param.name)
param.gradient = get_buffer(net, param.gradient_name)
param.hist = zeros(param.value)
set_buffer(net, param.hist_name, param.hist)
@latte_mpi param.request = @eval ccall((:init_request, $libComm), Cint, ())
end
end
end

@doc """
Initialize `ensemble` of neuron type `T`
Allocate a buffer for each field in T
Expand Down
221 changes: 128 additions & 93 deletions src/net.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,45 @@ function init_backward(ensemble::NormalizationEnsemble, net::Net,
unshift!(compute_body[ensemble.phase], body)
end

function add_recv_expr(net, source, ensemble, compute_body, compute_args)
key = symbol(source.name, :value)
source_subgroup = source.net_subgroup - 1 # -1 for zero based indexing with MPI ranks
tag = find(x -> x == ensemble, net.ensembles)[1]
expr = :(ccall((:recv_intra, $libComm), Void,
(Ptr{Float32}, Cint, Cint, Cint),
pointer($key), length($key), $tag, $source_subgroup))
if ensemble.phase in [Train, TrainTest] &&
connection.source.phase in [Train, TrainTest]
push!(compute_body[Train], expr)
push!(compute_args[Train], key)
end
if ensemble.phase in [Test, TrainTest] &&
connection.source.phase in [Train, TrainTest]
push!(compute_body[Test], expr)
push!(compute_args[Test], key)
end
end

function add_send_exprs(net, ensemble, compute_body, compute_args)
for (target, tag) in net.ensemble_send_list[ensemble.name]
target = target - 1 # 0-based indexing for MPI
target_phase = net.ensembles[tag].phase
target_buf = symbol(ensemble.name, :value)
expr = :(ccall((:send_intra, $libComm), Void,
(Ptr{Float32}, Cint, Cint, Cint),
pointer($target_buf), length($target_buf), $tag, $target))
if ensemble.phase in [Train, TrainTest] &&
target_phase in [Train, TrainTest]
push!(compute_body[Train], expr)
push!(compute_args[Train], target_buf)
end
if ensemble.phase in [Test, TrainTest] &&
target_phase in [Test, TrainTest]
push!(compute_body[Test], expr)
push!(compute_args[Test], target_buf)
end
end
end

function init(net::Net)
log_info("Initializing net...")
Expand All @@ -599,70 +638,58 @@ function init(net::Net)
backward_compute_body = Dict{Phase, Vector}(Train => [], Test => [])
backward_compute_args = ArgSet()
seen_names = Set()
# Initialize ensembles
log_info(" Initializing ensembles.")
for ensemble in net.ensembles
# Check for duplicate ensemble names
if ensemble.name in seen_names
throw("Error: Found duplicate ensemble name: $(ensemble.name)")
end
@latte_mpi (if ensemble.net_subgroup != get_net_subrank(net) + 1
continue # skip
end
)
push!(seen_names, ensemble.name)

# If in MPI mode skip ensembles not assigned to this subrank
@latte_mpi(if ensemble.net_subgroup != get_net_subrank(net) + 1
continue # skip
end)
net.ensemble_send_list[ensemble.name] = Tuple{Int, Int}[]
map(init, ensemble)

log_info(" $(ensemble.name) size=$(size(ensemble))")

# Initialize the ensemble
init(ensemble, net)
init_params(ensemble, net)
end

for (index, ensemble) in enumerate(net.ensembles)
@latte_mpi (if ensemble.net_subgroup != get_net_subrank(net) + 1
for connection in ensemble.connections
if connection.source.net_subgroup == get_net_subrank(net) + 1
push!(net.ensemble_send_list[connection.source.name],
(ensemble.net_subgroup, index))
end
# If in MPI mode, populate the send_list for connected ensembles not in
# this subrank and skip initializations
@latte_mpi if ensemble.net_subgroup != get_net_subrank(net) + 1
for connection in ensemble.connections
if connection.source.net_subgroup == get_net_subrank(net) + 1
push!(net.ensemble_send_list[connection.source.name],
(ensemble.net_subgroup, index))
end
continue # skip
end
)
continue # skip
end

# Initialize the inputs for the ensemble, this is done after all
# ensembles have initialized their outputs (necessary for rnns)
init_inputs(ensemble, net)
end
log_info(" Finished initializing ensembles.")

log_info(" Synthesizing forward functions.")
# Generate forward tasks
for ensemble in net.ensembles
@latte_mpi (if ensemble.net_subgroup != get_net_subrank(net) + 1
@latte_mpi if ensemble.net_subgroup != get_net_subrank(net) + 1
continue # skip
end)
# TODO: Should param initialization be done in ensemble init??
if :params in fieldnames(ensemble)
for param in ensemble.params
param.value = get_buffer(net, param.name)
param.gradient = get_buffer(net, param.gradient_name)
param.hist = zeros(param.value)
set_buffer(net, param.hist_name, param.hist)
@latte_mpi param.request = @eval ccall((:init_request, $libComm), Cint, ())
end
end

for connection in ensemble.connections
@latte_mpi for connection in ensemble.connections
if connection.source.net_subgroup != ensemble.net_subgroup
key = symbol(connection.source.name, :value)
source_subgroup = connection.source.net_subgroup - 1 # -1 for zero based indexing with MPI ranks
tag = find(x -> x == ensemble, net.ensembles)[1]
expr = :(ccall((:recv_intra, $libComm), Void, (Ptr{Float32}, Cint, Cint, Cint), pointer($key), length($key), $tag, $source_subgroup))
if ensemble.phase in [Train, TrainTest] &&
connection.source.phase in [Train, TrainTest]
push!(forward_compute_body[Train], expr)
push!(forward_compute_args[Train], key)
end
if ensemble.phase in [Test, TrainTest] &&
connection.source.phase in [Train, TrainTest]
push!(forward_compute_body[Test], expr)
push!(forward_compute_args[Test], key)
end
add_recv_expr(net, connection.source, ensemble,
forward_compute_body, forward_compute_args)

end
end
# Skip code generation for data and loss ensembles
Expand All @@ -677,24 +704,9 @@ function init(net::Net)
else
throw("Latte Error: Encountered unsupported ensemble type $(typeof(ensemble)).")
end
for (target, tag) in net.ensemble_send_list[ensemble.name]
target = target - 1 # 0-based indexing for MPI
target_phase = net.ensembles[tag].phase
target_buf = symbol(ensemble.name, :value)
expr = :(ccall((:send_intra, $libComm), Void,
(Ptr{Float32}, Cint, Cint, Cint),
pointer($target_buf), length($target_buf), $tag, $target))
if ensemble.phase in [Train, TrainTest] &&
target_phase in [Train, TrainTest]
push!(forward_compute_body[Train], expr)
push!(forward_compute_args[Train], target_buf)
end
if ensemble.phase in [Test, TrainTest] &&
target_phase in [Test, TrainTest]
push!(forward_compute_body[Test], expr)
push!(forward_compute_args[Test], target_buf)
end
end

@latte_mpi add_send_exprs(net, ensemble, forward_compute_body,
forward_compute_args)
end
append!(net.forward_tasks, forward_data_tasks)

Expand Down Expand Up @@ -773,27 +785,36 @@ function init(net::Net)
log_info("Initialization finished.")
end

function get_task_args(net, task_args, t)
args = []
for arg in task_args
if isa(arg, Symbol)
push!(args, get_buffer(net, arg, t))
else
push!(args, arg)
end
end
args
end

function run_task(task::JuliaTask)
end

# Use metaprogramming to generate single and multi versions of forward
# and backward.
for direction in [:forward, :backward]
tasks = symbol(direction,:_tasks)
@eval function $direction(net::Net; phase=Train, solver=nothing)
for t = 1:net.time_steps
net.curr_time_step = t
for (index, task) in enumerate(net.$tasks[phase])
for task in net.$tasks[phase]
if isa(task, JuliaTask)
args = []
for arg in task.args
if isa(arg, Symbol)
push!(args, get_buffer(net, arg, t))
else
push!(args, arg)
end
end
task.func(args...)
task.func(get_task_args(net, task.args, t)...)
elseif isa(task, UpdateTask)
@assert phase == Train && solver != nothing
update(solver, net, task.param_id)
else
throw("Unsupported task type $(typeof(task))")
end
end
end
Expand All @@ -808,6 +829,38 @@ function add_ensemble(net::Net, ens::AbstractEnsemble)
net.ensembles_map[ens.name] = ens
end

function check_dimensions_fixed(mapping::Function, sink_shape)
n = length(sink_shape)
is_dim_fixed = [true for _ in 1:n]
first = mapping(ones(Int, n)...)
if !all(map((x) -> isa(x, UnitRange) || isa(x, Colon), first))
is_dim_fixed = [false for _ in 1:n]
else
for d in 1:n
for i in 1:sink_shape[d]
idx = ones(Int, n)
idx[d] = i
if first != mapping(idx...)
is_dim_fixed[d] = false
break
end
end
end
end
is_dim_fixed
end

function check_one_to_one(mapping, shape)
is_one_to_one = true
for i in CartesianRange(shape)
if mapping(i.I...) != i.I
is_one_to_one = false
break
end
end
is_one_to_one
end

"""
Connect neurons in `source` to neurons in `sink` using the function `mapping`.
`mapping` should be a function with a parameter for the index in each dimension
Expand All @@ -821,6 +874,8 @@ current index
function add_connections(net::Net, source::AbstractEnsemble,
sink::AbstractEnsemble, mapping::Function; padding=0, recurrent=false)
n = ndims(sink)

# Compute the size and shape of the connection
range_size = 1
range_shape = []
for (index, d) in enumerate(mapping(ones(Int, n)...))
Expand All @@ -832,32 +887,12 @@ function add_connections(net::Net, source::AbstractEnsemble,
push!(range_shape, length(d))
end
end
is_dim_fixed = [true for _ in 1:n]
first = mapping(ones(Int, n)...)
if !all(map((x) -> isa(x, UnitRange) || isa(x, Colon), first))
is_dim_fixed = [false for _ in 1:n]
else
for d in 1:n
for i in 1:size(sink, d)
idx = ones(Int, n)
idx[d] = i
if first != mapping(idx...)
is_dim_fixed[d] = false
break
end
end
end
end
is_one_to_one = true

# Determine if any dimensions are fixed
is_dim_fixed = check_dimensions_fixed(mapping, size(sink))
is_one_to_one = false
if !all(is_dim_fixed)
for i in CartesianRange(size(sink))
if mapping(i.I...) != i.I
is_one_to_one = false
break
end
end
else
is_one_to_one = false
is_one_to_one = check_one_to_one(mapping, size(sink))
end
push!(sink.connections, Connection(source, mapping, tuple(range_shape...),
range_size, true, is_dim_fixed,
Expand Down

0 comments on commit 443940d

Please sign in to comment.