Skip to content

Commit

Permalink
feat: support count aggregate func on sharded table (#183)
Browse files Browse the repository at this point in the history
* feat: support count aggregate func on sharded table
  • Loading branch information
dk-lockdown committed Jul 14, 2022
1 parent 9ea2fd4 commit 8290190
Show file tree
Hide file tree
Showing 11 changed files with 513 additions and 99 deletions.
6 changes: 3 additions & 3 deletions pkg/listener/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data
tracing.RecordErrorSpan(span, err)
return err
}
err = c.WriteRows(rlt)
err = c.WriteTextRows(rlt)
if err != nil {
tracing.RecordErrorSpan(span, err)
return err
Expand All @@ -623,7 +623,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data
tracing.RecordErrorSpan(span, err)
return err
}
err = c.WriteRowsDirect(rlt)
err = c.WriteRows(rlt)
if err != nil {
tracing.RecordErrorSpan(span, err)
return err
Expand Down Expand Up @@ -802,7 +802,7 @@ func (l *MysqlListener) ExecuteCommand(ctx context.Context, c *mysql.Conn, data
tracing.RecordErrorSpan(span, err)
return err
}
err = c.WriteRowsDirect(rlt)
err = c.WriteRows(rlt)
if err != nil {
tracing.RecordErrorSpan(span, err)
return err
Expand Down
86 changes: 59 additions & 27 deletions pkg/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ func (c *Conn) WriteFields(capabilities uint32, fields []*Field) error {
return nil
}

func (c *Conn) writeRow(row []*proto.Value) error {
func (c *Conn) writeTextRow(row []*proto.Value) error {
length := 0
for _, val := range row {
if val == nil || val.Val == nil {
Expand Down Expand Up @@ -639,8 +639,8 @@ func (c *Conn) writeRow(row []*proto.Value) error {
return c.WriteEphemeralPacket()
}

// WriteRows sends the rows of a Result.
func (c *Conn) WriteRows(result *Result) error {
// WriteTextRows sends the rows of a Result.
func (c *Conn) WriteTextRows(result *Result) error {
for {
row, err := result.Rows.Next()
if err != nil {
Expand All @@ -654,7 +654,7 @@ func (c *Conn) WriteRows(result *Result) error {
if err != nil {
return err
}
if err := c.writeRow(values); err != nil {
if err := c.writeTextRow(values); err != nil {
return err
}
}
Expand Down Expand Up @@ -702,13 +702,13 @@ func (c *Conn) WritePrepare(capabilities uint32, prepare *proto.Stmt) error {
return nil
}

// WriteBinaryRow writes text row to binary row
func (c *Conn) writeBinaryRow(fields []*Field, row []*proto.Value) error {
// writeTextToBinaryRows writes text row to binary row
func (c *Conn) writeTextToBinaryRows(fields []*Field, row []*proto.Value) error {
length := 0
nullBitMapLen := (len(fields) + 7 + 2) / 8
for _, val := range row {
if val != nil && val.Val != nil {
l, err := packet.Val2MySQLLen(val)
l, err := packet.TextVal2MySQLLen(val)
if err != nil {
return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err)
}
Expand All @@ -733,7 +733,7 @@ func (c *Conn) writeBinaryRow(fields []*Field, row []*proto.Value) error {
bitPos := (i + 2) % 8
data[bytePos] |= 1 << uint(bitPos)
} else {
v, err := packet.Val2MySQL(val)
v, err := packet.TextVal2MySQL(val)
if err != nil {
c.RecycleWritePacket()
return fmt.Errorf("internal value %v to MySQL value error: %v", val, err)
Expand All @@ -749,26 +749,51 @@ func (c *Conn) writeBinaryRow(fields []*Field, row []*proto.Value) error {
return c.WriteEphemeralPacket()
}

// writeTextToBinaryRows sends the rows of a Result with binary form.
func (c *Conn) writeTextToBinaryRows(result *Result) error {
for {
row, err := result.Rows.Next()
if err != nil {
if err == io.EOF {
break
// writeBinaryRows writes text row to binary row
func (c *Conn) writeBinaryRows(fields []*Field, row []*proto.Value) error {
length := 0
nullBitMapLen := (len(fields) + 7 + 2) / 8
for _, val := range row {
if val != nil && val.Val != nil {
l, err := packet.BinaryVal2MySQLLen(val)
if err != nil {
return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err)
}
return err
}
textRow := TextRow{Row: row}
values, err := textRow.Decode()
if err != nil {
return err
length += l
}
if err := c.writeBinaryRow(result.Fields, values); err != nil {
return err
}

length += nullBitMapLen + 1

data := c.StartEphemeralPacket(length)
pos := 0

pos = misc.WriteByte(data, pos, 0x00)

for i := 0; i < nullBitMapLen; i++ {
pos = misc.WriteByte(data, pos, 0x00)
}

for i, val := range row {
if val == nil || val.Val == nil {
bytePos := (i+2)/8 + 1
bitPos := (i + 2) % 8
data[bytePos] |= 1 << uint(bitPos)
} else {
v, err := packet.BinaryVal2MySQL(val)
if err != nil {
c.RecycleWritePacket()
return fmt.Errorf("internal value %v to MySQL value error: %v", val, err)
}
pos += copy(data[pos:], v)
}
}
return nil

if pos != length {
return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length)
}

return c.WriteEphemeralPacket()
}

func (c *Conn) WriteBinaryRows(result *Result) error {
Expand All @@ -787,10 +812,17 @@ func (c *Conn) WriteBinaryRows(result *Result) error {
return nil
}

func (c *Conn) WriteRowsDirect(result *MergeResult) error {
func (c *Conn) WriteRows(result *MergeResult) error {
for _, row := range result.Rows {
if err := c.WritePacket(row.Data()); err != nil {
return err
switch r := row.(type) {
case *TextRow:
if err := c.writeTextRow(r.Values); err != nil {
return err
}
case *BinaryRow:
if err := c.writeBinaryRows(result.Fields, r.Values); err != nil {
return err
}
}
}
return nil
Expand Down

0 comments on commit 8290190

Please sign in to comment.