Skip to content

Commit

Permalink
Implement direct reads of cookie/version
Browse files Browse the repository at this point in the history
  • Loading branch information
amitmurthy committed May 23, 2016
1 parent dcde84d commit 1d99afd
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 67 deletions.
6 changes: 5 additions & 1 deletion base/docs/helpdb/Base.jl
Expand Up @@ -7927,7 +7927,11 @@ leading_ones
"""
deserialize(stream)
Read a value written by `serialize`.
Read a value written by `serialize`. `deserialize` assumes the binary data read from
`stream` is correct and has been serialized by a compatible implementation of `serialize`.
It has been designed with simplicity and performance as a goal and does not validate
the data read. Malformed data can result in process termination. The caller has to ensure
the integrity and correctness of data read from `stream`.
"""
deserialize

Expand Down
2 changes: 1 addition & 1 deletion base/initdefs.jl
Expand Up @@ -99,7 +99,7 @@ function init_parallel()
global PGRP
global LPROC
LPROC.id = 1
cluster_cookie(randstring())
cluster_cookie(randstring(HDR_COOKIE_LEN))
assert(isempty(PGRP.workers))
register_worker(LPROC)
end
Expand Down
172 changes: 111 additions & 61 deletions base/multi.jl
Expand Up @@ -61,24 +61,20 @@ end
# Worker initialization messages
type IdentifySocketMsg <: AbstractMsg
from_pid::Int
cookie::AbstractString
end
type IdentifySocketAckMsg <: AbstractMsg
cookie::AbstractString
end
type JoinPGRPMsg <: AbstractMsg
self_pid::Int
other_workers::Array
notify_oid::RRID
topology::Symbol
worker_pool
cookie::AbstractString
end
type JoinCompleteMsg <: AbstractMsg
notify_oid::RRID
cpu_cores::Int
ospid::Int
cookie::AbstractString
end


Expand Down Expand Up @@ -163,19 +159,22 @@ type Worker
w_stream::IO
manager::ClusterManager
config::WorkerConfig
version::Nullable{VersionNumber} # Julia version of the remote process

function Worker(id, r_stream, w_stream, manager, config)
function Worker(id::Int, r_stream::IO, w_stream::IO, manager::ClusterManager;
version=Nullable{VersionNumber}(), config=WorkerConfig())
w = Worker(id)
w.r_stream = r_stream
w.w_stream = buffer_writes(w_stream)
w.manager = manager
w.config = config
w.version = version
set_worker_state(w, W_CONNECTED)
register_worker_streams(w)
w
end

function Worker(id)
function Worker(id::Int)
@assert id > 0
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
Expand All @@ -188,8 +187,6 @@ type Worker
Worker() = Worker(get_next_pid())
end

Worker(id, r_stream, w_stream, manager) = Worker(id, r_stream, w_stream, manager, WorkerConfig())

function set_worker_state(w, state)
w.state = state
notify(w.c_state; all=true)
Expand Down Expand Up @@ -270,6 +267,17 @@ function flush_gc_msgs()
end
end

function send_connection_hdr(w::Worker, cookie=true)
# For a connection initiated from the remote side to us, we only send the version,
# else when we initiate a connection we first send the cookie followed by our version.
# The remote side validates the cookie.

if cookie
write(w.w_stream, LPROC.cookie)
end
write(w.w_stream, rpad(VERSION_STRING, HDR_VERSION_LEN)[1:HDR_VERSION_LEN])
end

## process group creation ##

type LocalProcess
Expand All @@ -282,8 +290,19 @@ end

const LPROC = LocalProcess()

const HDR_VERSION_LEN=16
const HDR_COOKIE_LEN=16
cluster_cookie() = LPROC.cookie
cluster_cookie(cookie) = (LPROC.cookie = cookie; cookie)
function cluster_cookie(cookie)
# The cookie must be an ASCII string with length <= HDR_COOKIE_LEN
assert(isascii(cookie))
assert(length(cookie) <= HDR_COOKIE_LEN)

cookie = rpad(cookie, HDR_COOKIE_LEN)

LPROC.cookie = cookie
cookie
end

const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
const map_sock_wrkr = ObjectIdDict()
Expand Down Expand Up @@ -952,40 +971,38 @@ function deliver_result(sock::IO, msg, oid, value)
end

## message event handlers ##
process_messages(r_stream::TCPSocket, w_stream::TCPSocket) = @schedule process_tcp_streams(r_stream, w_stream)

function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket)
disable_nagle(r_stream)
wait_connected(r_stream)
if r_stream != w_stream
disable_nagle(w_stream)
wait_connected(w_stream)
end
message_handler_loop(r_stream, w_stream)
function process_messages(r_stream::TCPSocket, w_stream::TCPSocket, incoming=true)
@schedule process_tcp_streams(r_stream, w_stream, incoming)
end

process_messages(r_stream::IO, w_stream::IO) = @schedule message_handler_loop(r_stream, w_stream)

function message_handler_loop(r_stream::IO, w_stream::IO)
try
# Check for a valid first message with a cookie.
msg = deserialize(r_stream)
if !any(x->isa(msg, x), [JoinPGRPMsg, JoinCompleteMsg, IdentifySocketMsg, IdentifySocketAckMsg]) ||
(msg.cookie != cluster_cookie())
function process_tcp_streams(r_stream::TCPSocket, w_stream::TCPSocket, incoming)
disable_nagle(r_stream)
wait_connected(r_stream)
if r_stream != w_stream
disable_nagle(w_stream)
wait_connected(w_stream)
end
message_handler_loop(r_stream, w_stream, incoming)
end

println(STDERR, "Unknown first message $(typeof(msg)) or cookie mismatch.")
error("Invalid connection credentials.")
end
function process_messages(r_stream::IO, w_stream::IO, incoming=true)
@schedule message_handler_loop(r_stream, w_stream, incoming)
end

function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
try
version = process_hdr(r_stream, incoming)
while true
handle_msg(msg, r_stream, w_stream)
msg = deserialize(r_stream)
# println("got msg: ", msg)
handle_msg(msg, r_stream, w_stream, version)
end
catch e
# println(STDERR, "Process($(myid())) - Exception ", e)
iderr = worker_id_from_socket(r_stream)
if (iderr < 1)
println(STDERR, "Socket from unknown remote worker in worker $(myid())")
println(STDERR, e)
println(STDERR, "Process($(myid())) - Unknown remote, closing connection.")
else
werr = worker_from_id(iderr)
oldstate = werr.state
Expand Down Expand Up @@ -1020,35 +1037,63 @@ function message_handler_loop(r_stream::IO, w_stream::IO)
end
end

handle_msg(msg::CallMsg{:call}, r_stream, w_stream) = schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
function handle_msg(msg::CallMsg{:call_fetch}, r_stream, w_stream)
function process_hdr(s, validate_cookie)
if validate_cookie
cookie = read(s, HDR_COOKIE_LEN)
self_cookie = cluster_cookie()
for i in 1:HDR_COOKIE_LEN
if UInt8(self_cookie[i]) != cookie[i]
error("Process($(myid())) - Invalid connection credentials sent by remote.")
end
end
end

# When we have incompatible julia versions trying to connect to each other,
# and can be detected, raise an appropriate error.
# For now, just return the version.
return VersionNumber(strip(String(read(s, HDR_VERSION_LEN))))
end

function handle_msg(msg::CallMsg{:call}, r_stream, w_stream, version)
schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
end
function handle_msg(msg::CallMsg{:call_fetch}, r_stream, w_stream, version)
@schedule begin
v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false)
deliver_result(w_stream, :call_fetch, msg.response_oid, v)
end
end

function handle_msg(msg::CallWaitMsg, r_stream, w_stream)
function handle_msg(msg::CallWaitMsg, r_stream, w_stream, version)
@schedule begin
rv = schedule_call(msg.response_oid, ()->msg.f(msg.args...; msg.kwargs...))
deliver_result(w_stream, :call_wait, msg.notify_oid, fetch(rv.c))
end
end

handle_msg(msg::RemoteDoMsg, r_stream, w_stream) = @schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true)
function handle_msg(msg::RemoteDoMsg, r_stream, w_stream, version)
@schedule run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), true)
end

handle_msg(msg::ResultMsg, r_stream, w_stream) = put!(lookup_ref(msg.response_oid), msg.value)
function handle_msg(msg::ResultMsg, r_stream, w_stream, version)
put!(lookup_ref(msg.response_oid), msg.value)
end

function handle_msg(msg::IdentifySocketMsg, r_stream, w_stream)
function handle_msg(msg::IdentifySocketMsg, r_stream, w_stream, version)
# register a new peer worker connection
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager)
send_msg_now(w, IdentifySocketAckMsg(cluster_cookie()))
w=Worker(msg.from_pid, r_stream, w_stream, cluster_manager; version=version)
send_connection_hdr(w, false)
send_msg_now(w, IdentifySocketAckMsg())
end

function handle_msg(msg::IdentifySocketAckMsg, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
w.version = version
end
handle_msg(msg::IdentifySocketAckMsg, r_stream, w_stream) = nothing

function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream)
function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream, version)
LPROC.id = msg.self_pid
controller = Worker(1, r_stream, w_stream, cluster_manager)
controller = Worker(1, r_stream, w_stream, cluster_manager; version=version)
register_worker(LPROC)
topology(msg.topology)

Expand All @@ -1066,29 +1111,31 @@ function handle_msg(msg::JoinPGRPMsg, r_stream, w_stream)
for wt in wait_tasks; wait(wt); end

set_default_worker_pool(msg.worker_pool)

send_msg_now(controller, JoinCompleteMsg(msg.notify_oid, Sys.CPU_CORES, getpid(), cluster_cookie()))
send_connection_hdr(controller, false)
send_msg_now(controller, JoinCompleteMsg(msg.notify_oid, Sys.CPU_CORES, getpid()))
end

function connect_to_peer(manager::ClusterManager, rpid::Int, wconfig::WorkerConfig)
try
(r_s, w_s) = connect(manager, rpid, wconfig)
w = Worker(rpid, r_s, w_s, manager, wconfig)
process_messages(w.r_stream, w.w_stream)
send_msg_now(w, IdentifySocketMsg(myid(), cluster_cookie()))
w = Worker(rpid, r_s, w_s, manager; config=wconfig)
process_messages(w.r_stream, w.w_stream, false)
send_connection_hdr(w, true)
send_msg_now(w, IdentifySocketMsg(myid()))
catch e
display_error(e, catch_backtrace())
println(STDERR, "Error [$e] on $(myid()) while connecting to peer $rpid. Exiting.")
exit(1)
end
end

function handle_msg(msg::JoinCompleteMsg, r_stream, w_stream)
function handle_msg(msg::JoinCompleteMsg, r_stream, w_stream, version)
w = map_sock_wrkr[r_stream]
environ = get(w.config.environ, Dict())
environ[:cpu_cores] = msg.cpu_cores
w.config.environ = environ
w.config.ospid = msg.ospid
w.version = version

ntfy_channel = lookup_ref(msg.notify_oid)
put!(ntfy_channel, w.id)
Expand Down Expand Up @@ -1128,7 +1175,7 @@ function start_worker(out::IO, cookie::AbstractString)
end
@schedule while isopen(sock)
client = accept(sock)
process_messages(client, client)
process_messages(client, client, true)
end
print(out, "julia_worker:") # print header
print(out, "$(dec(LPROC.bind_port))#") # print port
Expand Down Expand Up @@ -1357,7 +1404,7 @@ function create_worker(manager, wconfig)
w = Worker()

(r_s, w_s) = connect(manager, w.id, wconfig)
w = Worker(w.id, r_s, w_s, manager, wconfig)
w = Worker(w.id, r_s, w_s, manager; config=wconfig)
# install a finalizer to perform cleanup if necessary
finalizer(w, (w)->if myid() == 1 manage(w.manager, w.id, w.config, :finalize) end)

Expand All @@ -1368,19 +1415,21 @@ function create_worker(manager, wconfig)

# Start a new task to handle inbound messages from connected worker in master.
# Also calls `wait_connected` on TCP streams.
process_messages(w.r_stream, w.w_stream)
process_messages(w.r_stream, w.w_stream, false)

# send address information of all workers to the new worker.
# Cluster managers set the address of each worker in `WorkerConfig.connect_at`.
# A new worker uses this to setup a all-to-all network. Workers with higher pids connect to
# workers with lower pids. Except process 1 (master) which initiates connections
# to all workers.
# Flow:
# - master sends (:join_pgrp, list_of_all_worker_addresses) to all workers
# A new worker uses this to setup an all-to-all network if topology :all_to_all is specified.
# Workers with higher pids connect to workers with lower pids. Except process 1 (master) which
# initiates connections to all workers.

# Connection Setup Protocol:
# - Master sends 16-byte cookie followed by 16-byte version string and a JoinPGRP message to all workers
# - On each worker
# - each worker sends a :identify_socket to all workers less than its pid
# - each worker then sends a :join_complete back to the master along with its OS_PID and NUM_CORES
# - once master receives a :join_complete it triggers rr_ntfy_join (signifies that worker setup is complete)
# - Worker responds with a 16-byte version followed by a JoinCompleteMsg
# - Connects to all workers less than its pid. Sends the cookie, version and an IdentifySocket message
# - Workers with incoming connection requests write back their Version and an IdentifySocketAckMsg message
# - On master, receiving a JoinCompleteMsg triggers rr_ntfy_join (signifies that worker setup is complete)

join_list = []
if PGRP.topology == :all_to_all
Expand Down Expand Up @@ -1411,7 +1460,8 @@ function create_worker(manager, wconfig)
end

all_locs = map(x -> isa(x, Worker) ? (get(x.config.connect_at, ()), x.id) : ((), x.id, true), join_list)
send_msg_now(w, JoinPGRPMsg(w.id, all_locs, ntfy_oid, PGRP.topology, default_worker_pool(), cluster_cookie()))
send_connection_hdr(w, true)
send_msg_now(w, JoinPGRPMsg(w.id, all_locs, ntfy_oid, PGRP.topology, default_worker_pool()))

@schedule manage(w.manager, w.id, w.config, :register)
wait(rr_ntfy_join)
Expand Down
2 changes: 1 addition & 1 deletion doc/stdlib/io-network.rst
Expand Up @@ -300,7 +300,7 @@ General I/O

.. Docstring generated from Julia source
Read a value written by ``serialize``\ .
Read a value written by ``serialize``\ . ``deserialize`` assumes the binary data read from ``stream`` is correct and has been serialized by a compatible implementation of ``serialize``\ . It has been designed with simplicity and performance as a goal and does not validate the data read. Malformed data can result in process termination. The caller has to ensure the integrity and correctness of data read from ``stream``\ .

.. function:: print_escaped(io, str::AbstractString, esc::AbstractString)

Expand Down
6 changes: 3 additions & 3 deletions examples/clustermanager/0mq/ZMQCM.jl
Expand Up @@ -181,7 +181,7 @@ function start_master(np)
#println("master recv data from $from_zid")

(r_s, w_s, t_r) = manager.map_zmq_julia[from_zid]
write(r_s, convert(Ptr{Uint8}, data), length(data))
unsafe_write(r_s, pointer(data), length(data))
end
catch e
Base.show_backtrace(STDOUT,catch_backtrace())
Expand All @@ -197,7 +197,7 @@ end
function launch(manager::ZMQCMan, params::Dict, launched::Array, c::Condition)
#println("launch $(params[:np])")
for i in 1:params[:np]
io, pobj = open(`julia worker.jl $i $(Base.cluster_cookie())`, "r")
io, pobj = open(`$(params[:exename]) worker.jl $i $(Base.cluster_cookie())`, "r")

wconfig = WorkerConfig()
wconfig.userdata = Dict(:zid=>i, :io=>io)
Expand Down Expand Up @@ -247,7 +247,7 @@ function start_worker(zid, cookie)
(r_s, w_s, t_r) = streams
end

write(r_s, convert(Ptr{Uint8}, data), length(data))
unsafe_write(r_s, pointer(data), length(data))
end
end

Expand Down

0 comments on commit 1d99afd

Please sign in to comment.