Skip to content

Commit

Permalink
feat(xRPC): basic stream support
Browse files Browse the repository at this point in the history
Signed-off-by: spacewander <spacewanderlzx@gmail.com>
  • Loading branch information
spacewander committed Apr 19, 2022
1 parent 2518b18 commit 449d933
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 34 deletions.
5 changes: 5 additions & 0 deletions apisix/stream/xrpc/protocols/redis/init.lua
Expand Up @@ -208,6 +208,11 @@ function _M.connect_upstream(session, ctx)
end


function _M.disconnect_upstream(session, upstream, upstream_broken)
sdk.disconnect_upstream(upstream, session.upstream_conf, upstream_broken)
end


function _M.to_upstream(session, ctx, downstream, upstream)
local ok, err = upstream:move(downstream)
if not ok then
Expand Down
86 changes: 65 additions & 21 deletions apisix/stream/xrpc/runner.lua
Expand Up @@ -15,6 +15,8 @@
-- limitations under the License.
--
local core = require("apisix.core")
local pairs = pairs
local ngx = ngx
local ngx_now = ngx.now
local OK = ngx.OK
local DECLINED = ngx.DECLINED
Expand All @@ -27,65 +29,96 @@ local _M = {}
local function open_session(conn_ctx)
conn_ctx.xrpc_session = {
upstream_conf = conn_ctx.matched_upstream,
id_seq = 0,
_ctxs = {},
}
return conn_ctx.xrpc_session
end


local function close_session(session, upstream_broken)
local upstream = session.upstream
if upstream then
if upstream_broken then
upstream:close()
else
upstream:setkeepalive()
end
local function close_session(session, protocol)
local upstream_ctx = session._upstream_ctx
if upstream_ctx then
upstream_ctx.closed = true

local up = upstream_ctx.upstream
protocol.disconnect_upstream(session, up, upstream_ctx.broken)
end

for id in pairs(session._ctxs) do
core.log.info("RPC is not finished, id: ", id)
end
end


local function put_req_ctx(session, ctx)
local id = ctx.id
session.ctxs[id] = nil
session._ctxs[id] = nil

core.tablepool.release("xrpc_ctxs", ctx)
end


local function finish_req(protocol, session, ctx)
ctx.rpc_end_time = ngx_now()
ctx._rpc_end_time = ngx_now()

protocol.log(session, ctx)
put_req_ctx(session, ctx)
end


local function open_upstream(protocol, session, ctx)
if session.upstream then
return OK, session.upstream
if session._upstream_ctx then
return OK, session._upstream_ctx
end

local state, upstream = protocol.connect_upstream(session, session)
if state ~= OK then
return state, nil
end

session.upstream = upstream
return OK, upstream
session._upstream_ctx = {
upstream = upstream,
broken = false,
closed = false,
}
return OK, session._upstream_ctx
end


local function start_upstream_coroutine(session, protocol, downstream, up_ctx)
local upstream = up_ctx.upstream
while not up_ctx.closed do
local status, ctx = protocol.from_upstream(session, downstream, upstream)
if status ~= OK then
if ctx ~= nil then
finish_req(protocol, session, ctx)
end

if status == DECLINED then
-- fail to read
break
end

if status == DONE then
-- a rpc is finished
goto continue
end
end

::continue::
end
end


function _M.run(protocol, conn_ctx)
local session = open_session(conn_ctx)
local downstream = protocol.init_downstream(session)
local upstream_broken = false

while true do
local status, ctx = protocol.from_downstream(session, downstream)
if status ~= OK then
if ctx ~= nil then
finish_req(session, ctx)
finish_req(protocol, session, ctx)
end

if status == DECLINED then
Expand All @@ -100,14 +133,14 @@ function _M.run(protocol, conn_ctx)
end

-- need to do some auth/routing jobs before reaching upstream
local status, upstream = open_upstream(protocol, session, ctx)
local status, up_ctx = open_upstream(protocol, session, ctx)
if status ~= OK then
break
end

status = protocol.to_upstream(session, ctx, downstream, upstream)
status = protocol.to_upstream(session, ctx, downstream, up_ctx.upstream)
if status == DECLINED then
upstream_broken = true
up_ctx.broken = true
break
end

Expand All @@ -116,10 +149,21 @@ function _M.run(protocol, conn_ctx)
goto continue
end

if not up_ctx.coroutine then
local co, err = ngx.thread.spawn(
start_upstream_coroutine, session, protocol, downstream, up_ctx)
if not co then
core.log.error("failed to start upstream coroutine: ", err)
break
end

up_ctx.coroutine = co
end

::continue::
end

close_session(session, upstream_broken)
close_session(session, protocol)

-- return non-zero code to terminal the session
return 200
Expand Down
37 changes: 31 additions & 6 deletions apisix/stream/xrpc/sdk.lua
Expand Up @@ -21,6 +21,7 @@
local core = require("apisix.core")
local xrpc_socket = require("resty.apisix.stream.xrpc.socket")
local ngx_now = ngx.now
local error = error


local _M = {}
Expand All @@ -41,8 +42,10 @@ function _M.connect_upstream(node, up_conf)
core.log.error("failed to connect: ", err)
return nil
end
-- TODO: support timeout

if up_conf.scheme == "tls" then
-- TODO: support mTLS
local ok, err = sk:sslhandshake(nil, node.host)
if not ok then
core.log.error("failed to handshake: ", err)
Expand All @@ -55,22 +58,44 @@ end


---
-- Returns the request level ctx with an optional id
-- Returns disconnected xRPC upstream socket according to the configuration
--
-- @function xrpc.sdk.disconnect_upstream
-- @tparam table xRPC upstream socket
-- @tparam table upstream configuration
-- @tparam boolean is the upstream already broken
function _M.disconnect_upstream(upstream, up_conf, upstream_broken)
if upstream_broken then
upstream:close()
else
-- TODO: support keepalive according to the up_conf
upstream:setkeepalive()
end
end


---
-- Returns the request level ctx with an id
--
-- @function xrpc.sdk.get_req_ctx
-- @tparam table xrpc session
-- @tparam string optional ctx id
-- @treturn table the request level ctx
function _M.get_req_ctx(session, id)
if not id then
id = session.id_seq
session.id_seq = session.id_seq + 1
error("id is required")
end

local ctx = session._ctxs[id]
if ctx then
return ctx
end

local ctx = core.tablepool.fetch("xrpc_ctxs")
session.ctxs[id] = ctx
local ctx = core.tablepool.fetch("xrpc_ctxs", 4, 4)
ctx._id = id
session._ctxs[id] = ctx

ctx.rpc_start_time = ngx_now()
ctx._rpc_start_time = ngx_now()
return ctx
end

Expand Down
73 changes: 66 additions & 7 deletions t/xrpc/apisix/stream/xrpc/protocols/pingpong/init.lua
Expand Up @@ -15,6 +15,7 @@
-- limitations under the License.
--
local core = require("apisix.core")
local sdk = require("apisix.stream.xrpc.sdk")
local xrpc_socket = require("resty.apisix.stream.xrpc.socket")
local bit = require("bit")
local lshift = bit.lshift
Expand All @@ -36,6 +37,7 @@ local _M = {}
local HDR_LEN = 10
local TYPE_HEARTBEAT = 1
local TYPE_UNARY = 2
local TYPE_STREAM = 3


function _M.init_worker()
Expand All @@ -55,9 +57,7 @@ local function read_data(sk, len, body)
local f = body and sk.drain or sk.read
local p, err = f(sk, len)
if not p then
if err == "closed" then
core.log.info("failed to read: ", err)
else
if err ~= "closed" then
core.log.error("failed to read: ", err)
end
return nil
Expand Down Expand Up @@ -98,6 +98,9 @@ function _M.from_downstream(session, downstream)
return DONE
end

local stream_id = p[3] * 256 + p[4]
local ctx = sdk.get_req_ctx(session, stream_id)

local body_len = to_int32(p, 6)
core.log.info("read body len: ", body_len)

Expand All @@ -106,10 +109,11 @@ function _M.from_downstream(session, downstream)
return DECLINED
end

return OK, {
is_unary = typ == TYPE_UNARY,
len = HDR_LEN + body_len
}
ctx.is_unary = typ == TYPE_UNARY
ctx.is_stream = typ == TYPE_STREAM
ctx.id = stream_id
ctx.len = HDR_LEN + body_len
return OK, ctx
end


Expand Down Expand Up @@ -146,6 +150,14 @@ function _M.connect_upstream(session, ctx)
end


function _M.disconnect_upstream(session, upstream, upstream_broken)
-- disconnect upstream created by connect_upstream
-- the upstream_broken flag is used to indicate whether the upstream is
-- already broken
sdk.disconnect_upstream(upstream, session.upstream_conf, upstream_broken)
end


function _M.to_upstream(session, ctx, downstream, upstream)
-- send the request read from downstream to the upstream
-- return whether the request is sent
Expand Down Expand Up @@ -176,6 +188,53 @@ function _M.to_upstream(session, ctx, downstream, upstream)
end


function _M.from_upstream(session, downstream, upstream)
local p = read_data(upstream, HDR_LEN, false)
if p == nil then
return DECLINED
end

local p_b = str_byte("p")
if p[0] ~= p_b or p[1] ~= p_b then
core.log.error("invalid magic number: ", ffi_str(p, 2))
return DECLINED
end

local typ = p[2]
if typ == TYPE_HEARTBEAT then
core.log.info("send heartbeat")

-- need to reset read buf as we won't forward it
downstream:reset_read_buf()
downstream:send(ffi_str(p, HDR_LEN))
return DONE
end

local stream_id = p[3] * 256 + p[4]
local ctx = sdk.get_req_ctx(session, stream_id)

local body_len = to_int32(p, 6)
if body_len ~= ctx.len - HDR_LEN then
core.log.error("upstream body len mismatch, expected: ", ctx.len - HDR_LEN,
", actual: ", body_len)
return DECLINED
end

local p = read_data(upstream, body_len, true)
if p == nil then
return DECLINED
end

local ok, err = downstream:move(upstream)
if not ok then
core.log.error("failed to handle upstream: ", err)
return DECLINED
end

return DONE, ctx
end


function _M.log(session, ctx)
core.log.info("call pingpong's log")
end
Expand Down

0 comments on commit 449d933

Please sign in to comment.