Skip to content

Commit

Permalink
feat: support check conn liveness (#260)
Browse files Browse the repository at this point in the history
  • Loading branch information
dk-lockdown committed Aug 29, 2022
1 parent 1d2d32d commit 59b9ff1
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pkg/driver/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,16 @@ func (conn *BackendConnection) Connect(ctx context.Context) error {
} else {
typ = conn.conf.Net
}
netConn, err := net.Dial(typ, conn.conf.Addr)

var (
netConn net.Conn
err error
)
if conn.conf.Timeout > 0 {
netConn, err = net.DialTimeout(typ, conn.conf.Addr, conn.conf.Timeout)
} else {
netConn, err = net.Dial(typ, conn.conf.Addr)
}
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/errors/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ var (
// to trigger a resend.
// See https://github.com/go-sql-driver/mysql/pull/302
ErrBadConnNoWrite = errors.New("bad connection")

ErrUnexpectedRead = errors.New("unexpected read from socket")
)
64 changes: 64 additions & 0 deletions pkg/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"net"
"strings"
"sync"
"syscall"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -117,6 +118,9 @@ type Conn struct {
// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
// It can be allocated from bufPool or heap and should be recycled in the same manner.
currentEphemeralBuffer *[]byte

ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
}

// NewConn is an internal method to create a Conn. Used by client and server
Expand Down Expand Up @@ -474,6 +478,17 @@ func (c *Conn) WritePacket(data []byte) error {
w, unget := c.getWriter()
defer unget()

if c.ReadTimeout != 0 {
err := c.conn.SetReadDeadline(time.Now().Add(c.ReadTimeout))
if err != nil {
return err
}
}
err := connCheck(c.conn)
if err != nil {
return err
}

for {
// Packet length is capped to MaxPacketSize.
packetLength := length
Expand All @@ -487,6 +502,13 @@ func (c *Conn) WritePacket(data []byte) error {
header[1] = byte(packetLength >> 8)
header[2] = byte(packetLength >> 16)
header[3] = c.sequence

if c.WriteTimeout > 0 {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.WriteTimeout)); err != nil {
return err
}
}

if n, err := w.Write(header[:]); err != nil {
return errors.Wrapf(err, "Write(header) failed")
} else if n != 4 {
Expand Down Expand Up @@ -995,6 +1017,14 @@ func (c *Conn) SetUserName(userName string) {
c.userName = userName
}

func (c *Conn) SetReadTimeout(readTimeout time.Duration) {
c.ReadTimeout = readTimeout
}

func (c *Conn) SetWriteTimeout(writeTimeout time.Duration) {
c.WriteTimeout = writeTimeout
}

// RemoteAddr returns the underlying socket RemoteAddr().
func (c *Conn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
Expand Down Expand Up @@ -1034,3 +1064,37 @@ func (c *Conn) Close() {
func (c *Conn) IsClosed() bool {
return c.closed.Get()
}

func connCheck(conn net.Conn) error {
var sysErr error

sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
}
rawConn, err := sysConn.SyscallConn()
if err != nil {
return err
}

err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = err2.ErrUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
}
return true
})
if err != nil {
return err
}

return sysErr
}

0 comments on commit 59b9ff1

Please sign in to comment.