Skip to content

Commit

Permalink
debug: listen capabilities fix:#317 (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
Penglq committed Jul 31, 2022
1 parent e0d604b commit 9279009
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
15 changes: 15 additions & 0 deletions pkg/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ type Conn struct {
// It is only used by the server. These flags can be changed
// by Handler methods.
StatusFlags uint16

// Capabilities is the current set of features this connection
// is using. It is the features that are both supported by
// the client and the server, and currently in use.
// It is set during the initial handshake.
//
// It is only used for CapabilityClientDeprecateEOF
// and CapabilityClientFoundRows.
Capabilities uint32

// characterSet is the character set used by the other side of the
// connection.
// It is set during the initial handshake.
// See the values in constants.go.
CharacterSet uint8
}

// newConn is an internal method to create a Conn. Used by client and server
Expand Down
16 changes: 8 additions & 8 deletions pkg/mysql/execute_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ func (l *Listener) handleQuery(c *Conn, ctx *proto.Context) error {

fields, _ := ds.Fields()

if err = c.writeFields(l.capabilities, fields); err != nil {
if err = c.writeFields(fields); err != nil {
log.Errorf("write fields error %v: %v", ctx.ConnectionID, err)
return err
}
if err = c.writeDataset(ds); err != nil {
log.Errorf("write dataset error %v: %v", ctx.ConnectionID, err)
return err
}
if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil {
if err = c.writeEndResult(false, 0, 0, warn); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
Expand Down Expand Up @@ -239,13 +239,13 @@ func (l *Listener) handleStmtExecute(c *Conn, ctx *proto.Context) error {

fields, _ := ds.Fields()

if err = c.writeFields(l.capabilities, fields); err != nil {
if err = c.writeFields(fields); err != nil {
return err
}
if err = c.writeDatasetBinary(ds); err != nil {
return err
}
if err = c.writeEndResult(l.capabilities, false, 0, 0, warn); err != nil {
if err = c.writeEndResult(false, 0, 0, warn); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
Expand Down Expand Up @@ -296,7 +296,7 @@ func (l *Listener) handlePrepare(c *Conn, ctx *proto.Context) error {

l.stmts.Store(statementID, stmt)

return c.writePrepare(l.capabilities, stmt)
return c.writePrepare(stmt)
}

func (l *Listener) handleStmtReset(c *Conn, ctx *proto.Context) error {
Expand All @@ -317,17 +317,17 @@ func (l *Listener) handleSetOption(c *Conn, ctx *proto.Context) error {
if ok {
switch operation {
case 0:
l.capabilities |= mysql.CapabilityClientMultiStatements
c.Capabilities |= mysql.CapabilityClientMultiStatements
case 1:
l.capabilities &^= mysql.CapabilityClientMultiStatements
c.Capabilities &^= mysql.CapabilityClientMultiStatements
default:
log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", ctx.ConnectionID, ctx.Data)
if err := c.writeErrorPacket(mysql.ERUnknownComError, mysql.SSUnknownComError, "error handling packet: %v", ctx.Data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
}
if err := c.writeEndResult(l.capabilities, false, 0, 0, 0); err != nil {
if err := c.writeEndResult(false, 0, 0, 0); err != nil {
log.Errorf("Error writeEndResult error %v ", err)
return err
}
Expand Down
15 changes: 9 additions & 6 deletions pkg/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32) {
return
}

c.Capabilities = l.capabilities
c.CharacterSet = l.characterSet

// Negotiation worked, send OK packet.
if err = c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Cannot write OK packet to %s: %v", c, err)
Expand Down Expand Up @@ -1054,7 +1057,7 @@ func (c *Conn) writeColumnDefinition(field *Field) error {

// writeFields writes the fields of a Result. It should be called only
// if there are valid Columns in the result.
func (c *Conn) writeFields(capabilities uint32, fields []proto.Field) error {
func (c *Conn) writeFields(fields []proto.Field) error {
// Send the number of fields first.
if err := c.sendColumnCount(uint64(len(fields))); err != nil {
return err
Expand All @@ -1069,7 +1072,7 @@ func (c *Conn) writeFields(capabilities uint32, fields []proto.Field) error {
}

// Now send an EOF packet.
if capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
if c.Capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
// With CapabilityClientDeprecateEOF, we do not send this EOF.
if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil {
return err
Expand Down Expand Up @@ -1111,14 +1114,14 @@ func (c *Conn) writeDataset(ds proto.Dataset) error {

// writeEndResult concludes the sending of a Result.
// if more is set to true, then it means there are more results afterwords
func (c *Conn) writeEndResult(capabilities uint32, more bool, affectedRows, lastInsertID uint64, warnings uint16) error {
func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warnings uint16) error {
// Send either an EOF, or an OK packet.
// See doc.go.
flags := c.StatusFlags
if more {
flags |= mysql.ServerMoreResultsExists
}
if capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
if c.Capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
if err := c.writeEOFPacket(flags, warnings); err != nil {
return err
}
Expand All @@ -1133,7 +1136,7 @@ func (c *Conn) writeEndResult(capabilities uint32, more bool, affectedRows, last
}

// writePrepare writes a prepared query response to the wire.
func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error {
func (c *Conn) writePrepare(prepare *proto.Stmt) error {
paramsCount := prepare.ParamsCount

data := c.startEphemeralPacket(12)
Expand Down Expand Up @@ -1163,7 +1166,7 @@ func (c *Conn) writePrepare(capabilities uint32, prepare *proto.Stmt) error {
}

// Now send an EOF packet.
if capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
if c.Capabilities&mysql.CapabilityClientDeprecateEOF == 0 {
// With CapabilityClientDeprecateEOF, we do not send this EOF.
if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil {
return err
Expand Down

0 comments on commit 9279009

Please sign in to comment.