Skip to content

Commit

Permalink
Merge pull request #179 from ably/feature/rtn15h-recv-disconnected
Browse files Browse the repository at this point in the history
RTN15h*: Handle incoming DISCONNECTED while CONNECTED
  • Loading branch information
tcard committed Aug 14, 2020
2 parents aeb92dd + 777c45a commit da47812
Show file tree
Hide file tree
Showing 5 changed files with 406 additions and 72 deletions.
29 changes: 25 additions & 4 deletions ably/ablytest/recorders.go
Expand Up @@ -289,18 +289,35 @@ func (rec *StateRecorder) timeout() time.Duration {
return 15 * time.Second
}

func MessagePipe(in <-chan *proto.ProtocolMessage, out chan<- *proto.ProtocolMessage) func(string, *url.URL) (proto.Conn, error) {
type MessagePipeOption func(*pipeConn)

// MessagePipeWithNowFunc sets a function to get the current time. This time
// will be used to determine whether a Receive times out.
//
// If not set, receives won't timeout.
func MessagePipeWithNowFunc(now func() time.Time) MessagePipeOption {
return func(pc *pipeConn) {
pc.now = now
}
}

func MessagePipe(in <-chan *proto.ProtocolMessage, out chan<- *proto.ProtocolMessage, opts ...MessagePipeOption) func(string, *url.URL) (proto.Conn, error) {
return func(proto string, u *url.URL) (proto.Conn, error) {
return pipeConn{
pc := pipeConn{
in: in,
out: out,
}, nil
}
for _, opt := range opts {
opt(&pc)
}
return pc, nil
}
}

type pipeConn struct {
in <-chan *proto.ProtocolMessage
out chan<- *proto.ProtocolMessage
now func() time.Time
}

func (pc pipeConn) Send(msg *proto.ProtocolMessage) error {
Expand All @@ -309,13 +326,17 @@ func (pc pipeConn) Send(msg *proto.ProtocolMessage) error {
}

func (pc pipeConn) Receive(deadline time.Time) (*proto.ProtocolMessage, error) {
var timeout <-chan time.Time
if pc.now != nil {
timeout = time.After(deadline.Sub(pc.now()))
}
select {
case m, ok := <-pc.in:
if !ok {
return nil, io.EOF
}
return m, nil
case <-time.After(time.Until(deadline)):
case <-timeout:
return nil, errTimeout{}
}
}
Expand Down
40 changes: 20 additions & 20 deletions ably/realtime_client.go
Expand Up @@ -29,7 +29,10 @@ func NewRealtime(options ClientOptions) (*Realtime, error) {
c.Auth = rest.Auth
c.Channels = newChannels(c)
conn, err := newConn(c.opts(), rest.Auth, connCallbacks{
c.onChannelMsg, c.onReconnectMsg, c.onConnStateChange,
c.onChannelMsg,
c.onReconnected,
c.onReconnectionFailed,
c.onConnStateChange,
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -64,30 +67,27 @@ func (c *Realtime) onChannelMsg(msg *proto.ProtocolMessage) {
c.Channels.Get(msg.Channel).notify(msg)
}

func (c *Realtime) onReconnectMsg(msg *proto.ProtocolMessage, isNewID bool) {
switch msg.Action {
case proto.ActionConnected:
if msg.Error != nil || // RTN15c3
isNewID { // RTN15g3
for _, ch := range c.Channels.All() {
switch ch.State() {
// TODO: SUSPENDED
case StateChanAttaching, StateChanAttached:
ch.mayAttach(false, false)
}
}
}

case proto.ActionError:
// (RTN15c4)
func (c *Realtime) onReconnected(err *proto.ErrorInfo, isNewID bool) {
if err == nil /* RTN15c3 */ && !isNewID /* RTN15g3 */ {
return
}

for _, ch := range c.Channels.All() {
ch.state.syncSet(StateChanFailed, newErrorProto(msg.Error))
for _, ch := range c.Channels.All() {
switch ch.State() {
// TODO: SUSPENDED
case StateChanAttaching, StateChanAttached:
ch.mayAttach(false, false)
}
}
}

func tokenError(err *proto.ErrorInfo) bool {
func (c *Realtime) onReconnectionFailed(err *proto.ErrorInfo) {
for _, ch := range c.Channels.All() {
ch.state.syncSet(StateChanFailed, newErrorProto(err))
}
}

func isTokenError(err *proto.ErrorInfo) bool {
return err.StatusCode == http.StatusUnauthorized && (40140 <= err.Code && err.Code < 40150)
}

Expand Down
139 changes: 94 additions & 45 deletions ably/realtime_conn.go
Expand Up @@ -48,19 +48,22 @@ type Connection struct {
// with this set to true then its the first message/response after issuing the
// reconnection request.
reconnecting bool
// reauthorizing tracks if the current reconnection attempt is happening
// after a reauthorization, to avoid re-reauthorizing.
reauthorizing bool
}

type connCallbacks struct {
onChannelMsg func(*proto.ProtocolMessage)
// onReconnectMsg is called when we get a response from reconnect request. We
// onReconnected is called when we get a CONNECTED response from reconnect request. We
// move this up because some implementation details for (RTN15c) requires
// access to Channels and we dont have it here so we let RealtimeClient do the
// work.
onReconnectMsg func(_ *proto.ProtocolMessage, isNewID bool)
onStateChange func(State)
// reconnecting tracks if we have issued a reconnection request. If we receive any message
// with this set to true then its the first message/response after issuing the
onReconnected func(_ *proto.ErrorInfo, isNewID bool)
// onReconnectionFailed is called when we get a FAILED response from a
// reconnection request.
onReconnectionFailed func(*proto.ErrorInfo)
onStateChange func(State)
}

func newConn(opts *clientOptions, auth *Auth, callbacks connCallbacks) (*Connection, error) {
Expand Down Expand Up @@ -97,6 +100,12 @@ func (c *Connection) dial(proto string, u *url.URL) (proto.Conn, error) {
// Connect attempts to move the connection to the CONNECTED state, if it
// can and if it isn't already.
func (c *Connection) Connect() {
c.state.Lock()
isActive := c.isActive()
c.state.Unlock()
if isActive {
return
}
c.connect(false)
}

Expand Down Expand Up @@ -200,10 +209,9 @@ func (c *Connection) params(mode connectionMode) (url.Values, error) {
func (c *Connection) connectWith(result bool, mode connectionMode) (Result, error) {
c.state.Lock()
defer c.state.Unlock()
if c.isActive() {
return nopResult, nil
if !c.isActive() {
c.setState(StateConnConnecting, nil)
}
c.setState(StateConnConnecting, nil)
u, err := url.Parse(c.opts.realtimeURL())
if err != nil {
return nil, c.setState(StateConnFailed, err)
Expand Down Expand Up @@ -537,84 +545,89 @@ func (c *Connection) eventloop() {
c.callbacks.onChannelMsg(msg)
break
}

c.state.Lock()
if c.reconnecting {
c.reconnecting = false
if tokenError(msg.Error) {
// (RTN15c5)
// TODO: (gernest) implement (RTN15h) This can be done as a separate task?
reauthorizing := c.reauthorizing
c.reauthorizing = false
if isTokenError(msg.Error) {
if reauthorizing {
c.lockedReauthorizationFailed(newErrorProto(msg.Error))
c.state.Unlock()
return
} else {
// (RTN15c4)
c.callbacks.onReconnectMsg(msg, false)
// TODO: RTN14b; may reuse c.reauthorize from RTN15h2.
}
}
c.setState(StateConnFailed, newErrorProto(msg.Error))
c.state.Unlock()
c.queue.Fail(newErrorProto(msg.Error))
if c.conn != nil {
c.conn.Close()
}

c.failedConnSideEffects(msg.Error)
case proto.ActionConnected:
c.state.Lock()
// we need to get this before we set c.key so as to be sure if we were
// resuming or recovering the connection.
mode := c.getMode()
c.state.Unlock()
if msg.ConnectionDetails != nil {
connDetails = msg.ConnectionDetails

c.state.Lock()
c.key = connDetails.ConnectionKey //(RTN15e) (RTN16d)
c.state.Unlock()

// Spec RSA7b3, RSA7b4, RSA12a
c.auth.updateClientID(connDetails.ClientID)
}
c.state.Lock()
reconnecting := c.reconnecting
if reconnecting {
// reset the mode
c.reconnecting = false
c.reauthorizing = false
}
id := c.id
previousID := c.id
c.id = msg.ConnectionID
c.msgSerial = 0
if reconnecting && mode == recoveryMode {
// we are setting msgSerial as per (RTN16f)
msgSerial, err := strconv.ParseInt(strings.Split(c.opts.Recover, ":")[2], 10, 64)
if err != nil {
//TODO: how to handle this? Panic?
}
c.msgSerial = msgSerial
}
c.setSerial(-1)
c.state.Unlock()
if reconnecting {
// (RTN15c1) (RTN15c2)
c.state.Lock()
c.setState(StateConnConnected, newErrorProto(msg.Error))
c.state.Unlock()
if id != msg.ConnectionID {
if previousID != msg.ConnectionID {
// (RTN15c3)
// we are calling this outside of locks to avoid deadlock because in the
// RealtimeClient client where this callback is implemented we do some ops
// with this Conn where we re acquire Conn.state.Lock again.
c.callbacks.onReconnectMsg(msg, true)
c.callbacks.onReconnected(msg.Error, true)
}
} else {
// preserve old behavior.
c.state.Lock()
c.setState(StateConnConnected, nil)
c.state.Unlock()
}
c.state.Lock()
c.id = msg.ConnectionID
c.msgSerial = 0
if reconnecting && mode == recoveryMode {
// we are setting msgSerial as per (RTN16f)
msgSerial, err := strconv.ParseInt(strings.Split(c.opts.Recover, ":")[2], 10, 64)
if err != nil {
//TODO: how to handle this? Panic?
}
c.msgSerial = msgSerial
}
c.setSerial(-1)
c.state.Unlock()
c.queue.Flush()
case proto.ActionDisconnected:
c.state.Lock()
c.id = ""
c.setState(StateConnDisconnected, nil)
c.state.Unlock()
if !isTokenError(msg.Error) {
// The spec doesn't say what to do in this case, so do nothing.
// Ably is supposed to then close the transport, which will
// trigger a transition to DISCONNECTED.
continue
}

if !c.auth.isTokenRenewable() {
// RTN15h1
c.failedConnSideEffects(msg.Error)
return
}

// RTN15h2
c.reauthorize(lastActivityAt, connDetails)
return
case proto.ActionClosed:
c.state.Lock()
c.id, c.key = "", "" //(RTN16c)
Expand All @@ -641,6 +654,42 @@ func (c *Connection) setState(state StateEnum, err error) error {
return c.state.set(state, err)
}

func (c *Connection) failedConnSideEffects(err *proto.ErrorInfo) {
c.state.Lock()
if c.reconnecting {
c.reconnecting = false
c.reauthorizing = false
c.callbacks.onReconnectionFailed(err)
}
c.setState(StateConnFailed, newErrorProto(err))
c.state.Unlock()
c.queue.Fail(newErrorProto(err))
if c.conn != nil {
c.conn.Close()
}
}

func (c *Connection) reauthorize(lastActivityAt time.Time, connDetails *proto.ConnectionDetails) {
c.state.Lock()
_, err := c.auth.reauthorize()
if err != nil {
c.lockedReauthorizationFailed(err)
c.state.Unlock()
return
}

// The reauthorize above will have set the new token in c.auth, so
// reconnecting will use the new token.
c.reauthorizing = true
c.state.Unlock()
c.reconnect(lastActivityAt, connDetails, false)
}

func (c *Connection) lockedReauthorizationFailed(err error) {
c.setState(StateConnDisconnected, err)
// TODO: RTN14d
}

type verboseConn struct {
conn proto.Conn
logger *LoggerOptions
Expand Down

0 comments on commit da47812

Please sign in to comment.