Skip to content

Commit

Permalink
Support lastMessage() and rowsAffected()
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 committed Nov 17, 2023
1 parent f20b286 commit 9bc1824
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 77 deletions.
6 changes: 3 additions & 3 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,15 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
}

func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
func (mc *MysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc)
}

func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
func (mc *MysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
Expand Down Expand Up @@ -296,7 +296,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
}
}

func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
func (mc *MysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
// Read Result Packet
authData, newPlugin, err := mc.readAuthResult()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
}

func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
mc := &MysqlConn{
cfg: &Config{
InterpolateParams: true,
Loc: time.UTC,
Expand Down
71 changes: 44 additions & 27 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"time"
)

type mysqlConn struct {
type MysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
lastMessage string
cfg *Config
maxAllowedPacket int
maxWriteSize int
Expand All @@ -45,8 +46,18 @@ type mysqlConn struct {
closed atomicBool // set when conn is closed, before closech is closed
}

// RowsAffected returns the number of rows affected by the query.
func (mc *MysqlConn) RowsAffected() uint64 {
return mc.affectedRows
}

// LastMessage returns the database's last message.
func (mc *MysqlConn) LastMessage() string {
return mc.lastMessage
}

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
func (mc *MysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
for param, val := range mc.cfg.Params {
switch param {
Expand Down Expand Up @@ -89,7 +100,7 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

func (mc *mysqlConn) markBadConn(err error) error {
func (mc *MysqlConn) markBadConn(err error) error {
if mc == nil {
return err
}
Expand All @@ -99,11 +110,11 @@ func (mc *mysqlConn) markBadConn(err error) error {
return driver.ErrBadConn
}

func (mc *mysqlConn) Begin() (driver.Tx, error) {
func (mc *MysqlConn) Begin() (driver.Tx, error) {
return mc.begin(false)
}

func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
func (mc *MysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -121,7 +132,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
return nil, mc.markBadConn(err)
}

func (mc *mysqlConn) Close() (err error) {
func (mc *MysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.closed.Load() {
err = mc.writeCommandPacket(comQuit)
Expand All @@ -136,7 +147,7 @@ func (mc *mysqlConn) Close() (err error) {
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
func (mc *MysqlConn) cleanup() {
if mc.closed.Swap(true) {
return
}
Expand All @@ -151,7 +162,7 @@ func (mc *mysqlConn) cleanup() {
}
}

func (mc *mysqlConn) error() error {
func (mc *MysqlConn) error() error {
if mc.closed.Load() {
if err := mc.canceled.Value(); err != nil {
return err
Expand All @@ -161,7 +172,7 @@ func (mc *mysqlConn) error() error {
return nil
}

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
func (mc *MysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand Down Expand Up @@ -195,7 +206,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return stmt, err
}

func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
func (mc *MysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
// Number of ? should be same to len(args)
if strings.Count(query, "?") != len(args) {
return "", driver.ErrSkip
Expand Down Expand Up @@ -294,7 +305,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return string(buf), nil
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
func (mc *MysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -312,6 +323,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}
mc.affectedRows = 0
mc.insertId = 0
mc.lastMessage = ""

err := mc.exec(query)
if err == nil {
Expand All @@ -324,7 +336,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
func (mc *MysqlConn) exec(query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
return mc.markBadConn(err)
Expand All @@ -351,11 +363,11 @@ func (mc *mysqlConn) exec(query string) error {
return mc.discardResults()
}

func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
func (mc *MysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.query(query, args)
}

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
func (mc *MysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -371,6 +383,11 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
}
query = prepared
}

mc.affectedRows = 0
mc.insertId = 0
mc.lastMessage = ""

// Send command
err := mc.writeCommandPacketStr(comQuery, query)
if err == nil {
Expand Down Expand Up @@ -402,7 +419,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)

// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
func (mc *MysqlConn) getSystemVar(name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
return nil, err
Expand Down Expand Up @@ -431,13 +448,13 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
}

// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
func (mc *MysqlConn) cancel(err error) {
mc.canceled.Set(err)
mc.cleanup()
}

// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
func (mc *MysqlConn) finish() {
if !mc.watching || mc.finished == nil {
return
}
Expand All @@ -449,7 +466,7 @@ func (mc *mysqlConn) finish() {
}

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
func (mc *MysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
Expand All @@ -468,7 +485,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
}

// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
func (mc *MysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if mc.closed.Load() {
return nil, driver.ErrBadConn
}
Expand All @@ -492,7 +509,7 @@ func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
return mc.begin(opts.ReadOnly)
}

func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
func (mc *MysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
Expand All @@ -511,7 +528,7 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv
return rows, err
}

func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (mc *MysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
Expand All @@ -525,7 +542,7 @@ func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []drive
return mc.Exec(query, dargs)
}

func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
func (mc *MysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
Expand Down Expand Up @@ -578,7 +595,7 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue
return stmt.Exec(dargs)
}

func (mc *mysqlConn) watchCancel(ctx context.Context) error {
func (mc *MysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
Expand All @@ -603,7 +620,7 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
return nil
}

func (mc *mysqlConn) startWatcher() {
func (mc *MysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
Expand All @@ -628,14 +645,14 @@ func (mc *mysqlConn) startWatcher() {
}()
}

func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
func (mc *MysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}

// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
func (mc *MysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.Load() {
return driver.ErrBadConn
}
Expand All @@ -645,6 +662,6 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {

// IsValid implements driver.Validator interface
// (From Go 1.15)
func (mc *mysqlConn) IsValid() bool {
func (mc *MysqlConn) IsValid() bool {
return !mc.closed.Load()
}
18 changes: 9 additions & 9 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
)

func TestInterpolateParams(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
Expand All @@ -38,7 +38,7 @@ func TestInterpolateParams(t *testing.T) {
}

func TestInterpolateParamsJSONRawMessage(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
Expand All @@ -65,7 +65,7 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) {
}

func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
Expand All @@ -82,7 +82,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
// We don't support placeholder in string literal for now.
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
Expand All @@ -98,7 +98,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) {
}

func TestInterpolateParamsUint64(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
Expand All @@ -117,7 +117,7 @@ func TestInterpolateParamsUint64(t *testing.T) {

func TestCheckNamedValue(t *testing.T) {
value := driver.NamedValue{Value: ^uint64(0)}
x := &mysqlConn{}
x := &MysqlConn{}
err := x.CheckNamedValue(&value)

if err != nil {
Expand All @@ -132,7 +132,7 @@ func TestCheckNamedValue(t *testing.T) {
// TestCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
mc := &MysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
Expand All @@ -159,7 +159,7 @@ func TestCleanCancel(t *testing.T) {

func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
ms := &mysqlConn{
ms := &MysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
Expand All @@ -174,7 +174,7 @@ func TestPingMarkBadConnection(t *testing.T) {

func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
ms := &mysqlConn{
ms := &MysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
Expand Down
2 changes: 1 addition & 1 deletion connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error

// New mysqlConn
mc := &mysqlConn{
mc := &MysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
Expand Down
Loading

0 comments on commit 9bc1824

Please sign in to comment.