Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Fetching contributors…

Cannot retrieve contributors at this time

771 lines (738 sloc) 16.831 kb
// GoMySQL - A MySQL client library for Go
//
// Copyright 2010-2011 Phil Bayfield. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mysql
import (
"os"
"reflect"
"strconv"
)
// Statement holds the client-side state of one prepared statement: the
// client it talks through, its prepare/bind status, the bound parameters
// and the outcome of the most recent command.
type Statement struct {
	// Client the statement issues its commands through
	c *Client

	// Status flags, plus the SQL text stored by the last successful Prepare
	prepared      bool
	preparedSql   string
	paramsBound   bool
	paramsRebound bool

	// Statement id used in execute, long data, reset and close commands
	statementId uint32

	// Parameters: expected count and one type/data entry per bound value
	paramCount           uint16
	paramType, paramData [][]byte

	// Number of columns (fields)
	columnCount uint64

	// Outcome of the last command
	AffectedRows, LastInsertId uint64
	Warnings                   uint16

	// Current result set (nil when there is none) and the destinations
	// registered through BindResult
	result       *Result
	resultParams []interface{}
}
// Prepare sends sql to the server as a new prepared statement.
//
// It fails with CR_COMMANDS_OUT_OF_SYNC when the connection check fails or
// a result set is still pending on this statement. After the prepare reply
// has been read, the parameter packets (when paramCount > 0) and the field
// packets (when columnCount > 0) are consumed up to their EOF packet. On
// success the statement is flagged as prepared and sql is remembered so
// Execute can re-prepare it after a reconnect.
//
// If the call fails with a network error and the client has Reconnect
// enabled, the connection is re-established and the prepare is retried.
func (s *Statement) Prepare(sql string) (err os.Error) {
	// Auto reconnect
	defer func() {
		if err != nil && s.c.checkNet(err) && s.c.Reconnect {
			s.c.log(1, "!!! Lost connection to server !!!")
			s.c.connected = false
			err = s.c.reconnect()
			if err == nil {
				err = s.Prepare(sql)
			}
		}
	}()
	// Log prepare
	s.c.log(1, "=== Begin prepare '%s' ===", sql)
	// Pre-run checks
	if !s.c.checkConn() || s.checkResult() {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Reset client
	s.reset()
	// Send prepare command
	err = s.c.command(COM_STMT_PREPARE, sql)
	if err != nil {
		return
	}
	// Read result from server
	s.c.sequence++
	_, err = s.getResult(PACKET_PREPARE_OK | PACKET_ERROR)
	if err != nil {
		return
	}
	// Read param packets
	if s.paramCount > 0 {
		// eof is declared outside the loop so the assignment below writes
		// to the named result err instead of shadowing it; with a shadowed
		// err the bare return would lose the failure.
		var eof bool
		for !eof {
			s.c.sequence++
			eof, err = s.getResult(PACKET_PARAM | PACKET_EOF)
			if err != nil {
				return
			}
		}
	}
	// Read field packets
	if s.columnCount > 0 {
		err = s.getFields()
		if err != nil {
			return
		}
	}
	// Statement is prepared
	s.prepared = true
	s.preparedSql = sql
	return
}
// ParamCount reports how many parameters the statement expects.
func (s *Statement) ParamCount() uint16 {
	count := s.paramCount
	return count
}
// BindParams converts params into their binary protocol form and stores
// them for the next Execute.
//
// The statement must be prepared (CR_NO_PREPARE_STMT otherwise) and
// len(params) must equal ParamCount (CR_INVALID_PARAMETER_NO otherwise).
// Supported values are nil, the built-in integer and float types, string
// and []byte; any other type yields CR_UNSUPPORTED_PARAM_TYPE.
//
// The new binding is built in local slices and only installed once every
// parameter has been converted, so a failed call leaves the previous
// binding intact instead of a half-filled one.
func (s *Statement) BindParams(params ...interface{}) (err os.Error) {
	// Second byte of a parameter type entry: the protocol marks unsigned
	// values with the high bit so the server does not read them as signed.
	const unsignedFlag = 0x80
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check number of params is correct
	if len(params) != int(s.paramCount) {
		return &ClientError{CR_INVALID_PARAMETER_NO, CR_INVALID_PARAMETER_NO_STR}
	}
	// Build the new binding separately from the current one
	types := make([][]byte, 0, len(params))
	values := make([][]byte, 0, len(params))
	// Convert params into bytes
	for k, param := range params {
		// Temp vars
		var t FieldType
		var d []byte
		var flag byte
		// Switch on type
		switch v := param.(type) {
		// Nil carries no data, it is reported through the null bitmap
		case nil:
			t = FIELD_TYPE_NULL
		// Int, width depends on the platform
		case int:
			t = FIELD_TYPE_LONGLONG
			if strconv.IntSize == 32 {
				t = FIELD_TYPE_LONG
			}
			d = itob(v)
		// Uint, width depends on the platform
		case uint:
			t = FIELD_TYPE_LONGLONG
			if strconv.IntSize == 32 {
				t = FIELD_TYPE_LONG
			}
			flag = unsignedFlag
			d = uitob(v)
		// Int8
		case int8:
			t = FIELD_TYPE_TINY
			d = []byte{byte(v)}
		// Uint8
		case uint8:
			t = FIELD_TYPE_TINY
			flag = unsignedFlag
			d = []byte{v}
		// Int16
		case int16:
			t = FIELD_TYPE_SHORT
			d = i16tob(v)
		// Uint16
		case uint16:
			t = FIELD_TYPE_SHORT
			flag = unsignedFlag
			d = ui16tob(v)
		// Int32
		case int32:
			t = FIELD_TYPE_LONG
			d = i32tob(v)
		// Uint32
		case uint32:
			t = FIELD_TYPE_LONG
			flag = unsignedFlag
			d = ui32tob(v)
		// Int64
		case int64:
			t = FIELD_TYPE_LONGLONG
			d = i64tob(v)
		// Uint64
		case uint64:
			t = FIELD_TYPE_LONGLONG
			flag = unsignedFlag
			d = ui64tob(v)
		// Float32
		case float32:
			t = FIELD_TYPE_FLOAT
			d = f32tob(v)
		// Float64
		case float64:
			t = FIELD_TYPE_DOUBLE
			d = f64tob(v)
		// String, length prefix followed by the bytes
		case string:
			t = FIELD_TYPE_STRING
			d = lcbtob(uint64(len(v)))
			d = append(d, []byte(v)...)
		// Byte array, length prefix followed by the bytes
		case []byte:
			t = FIELD_TYPE_BLOB
			d = lcbtob(uint64(len(v)))
			d = append(d, v...)
		// Other types
		default:
			return &ClientError{CR_UNSUPPORTED_PARAM_TYPE, s.c.fmtError(CR_UNSUPPORTED_PARAM_TYPE_STR, reflect.ValueOf(param).Type(), k)}
		}
		// Append values
		types = append(types, []byte{byte(t), flag})
		values = append(values, d)
	}
	// Install the binding and flag params as (re)bound
	s.paramType = types
	s.paramData = values
	s.paramsBound = true
	s.paramsRebound = true
	return
}
// SendLongData sends data for parameter number num (zero based) to the
// server with COM_STMT_SEND_LONG_DATA, splitting it into several packets
// when it does not fit into one.
//
// The statement must be prepared (CR_NO_PREPARE_STMT), num must address an
// existing parameter (CR_INVALID_PARAMETER_NO) and the connection must be
// usable with no result pending (CR_COMMANDS_OUT_OF_SYNC). Empty data
// still produces a single packet. The returned error is passed through the
// client's simpleReconnect hook.
func (s *Statement) SendLongData(num int, data []byte) (err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log send long data
	s.c.log(1, "=== Begin send long data ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check the param number, it is narrowed to uint16 in the packet and
	// would otherwise silently address a different parameter
	if num < 0 || num >= int(s.paramCount) {
		return &ClientError{CR_INVALID_PARAMETER_NO, CR_INVALID_PARAMETER_NO_STR}
	}
	// Pre-run checks
	if !s.c.checkConn() || s.checkResult() {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Reset client
	s.reset()
	// Largest data chunk carried by one packet
	maxChunk := MAX_PACKET_SIZE - 12
	// Data position (if data is longer than max packet length)
	pos := 0
	// Send data
	for {
		// Construct packet
		p := &packetLongData{
			command:     uint8(COM_STMT_SEND_LONG_DATA),
			statementId: s.statementId,
			paramNumber: uint16(num),
		}
		// Add protocol and sequence
		p.protocol = s.c.protocol
		p.sequence = s.c.sequence
		// Add data; the end index is relative to pos so every chunk after
		// the first one starts where the previous chunk stopped
		end := pos + maxChunk
		if end > len(data) {
			end = len(data)
		}
		p.data = data[pos:end]
		pos = end
		// Write packet
		err = s.c.w.writePacket(p)
		if err != nil {
			return
		}
		// Log write success
		s.c.log(1, "[%d] Sent long data packet", p.sequence)
		// Check if all data sent
		if pos == len(data) {
			break
		}
		// Increment sequence
		s.c.sequence++
	}
	return
}
// Execute runs the prepared statement with the currently bound parameters.
//
// It requires a prepared statement (CR_NO_PREPARE_STMT), bound parameters
// when the statement has any (CR_PARAMS_NOT_BOUND), a usable connection
// and no pending result (CR_COMMANDS_OUT_OF_SYNC). The execute packet is
// sent without a cursor and with an iteration count of 1. When the server
// answers with a result set, its field packets are read as well.
//
// If the call fails with a network error and the client has Reconnect
// enabled, the connection is re-established, the statement is prepared
// again from the stored SQL and the execute is retried.
func (s *Statement) Execute() (err os.Error) {
	// Auto reconnect
	defer func() {
		if err != nil && s.c.checkNet(err) && s.c.Reconnect {
			s.c.log(1, "!!! Lost connection to server !!!")
			s.c.connected = false
			err = s.c.reconnect()
			if err == nil {
				err = s.Prepare(s.preparedSql)
				if err == nil {
					// The re-prepared statement is new to the server, so
					// the parameter types have to be sent again even if
					// they were already delivered before the connection
					// was lost.
					s.paramsRebound = true
					err = s.Execute()
				}
			}
		}
	}()
	// Log execute
	s.c.log(1, "=== Begin execute ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check params bound
	if s.paramCount > 0 && !s.paramsBound {
		return &ClientError{CR_PARAMS_NOT_BOUND, CR_PARAMS_NOT_BOUND_STR}
	}
	// Pre-run checks
	if !s.c.checkConn() || s.checkResult() {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Reset client
	s.reset()
	// Construct packet
	p := &packetExecute{
		command:        byte(COM_STMT_EXECUTE),
		statementId:    s.statementId,
		flags:          byte(CURSOR_TYPE_NO_CURSOR),
		iterationCount: 1,
		nullBitMap:     s.getNullBitMap(),
		paramType:      s.paramType,
		paramData:      s.paramData,
	}
	// Add protocol and sequence
	p.protocol = s.c.protocol
	p.sequence = s.c.sequence
	// Add rebound flag
	if s.paramsRebound {
		p.newParamsBound = byte(1)
	}
	// Write packet
	err = s.c.w.writePacket(p)
	if err != nil {
		return
	}
	// Log write success
	s.c.log(1, "[%d] Sent execute packet", p.sequence)
	// Read result from server
	s.c.sequence++
	_, err = s.getResult(PACKET_OK | PACKET_ERROR | PACKET_RESULT)
	if err != nil {
		return
	}
	// The server accepted the packet, so the types need not be resent
	// until BindParams is called again. This is done before the result
	// check so statements without a result set clear the flag too.
	s.paramsRebound = false
	// Nothing more to read without a result set
	if s.result == nil {
		return
	}
	// Store fields
	err = s.getFields()
	return
}
// FieldCount returns the field count recorded on the current result set,
// or 0 when the statement has no result set.
func (s *Statement) FieldCount() uint64 {
	if !s.checkResult() {
		return 0
	}
	return s.result.fieldCount
}
// FetchColumn returns the next field of the current result set and moves
// the field cursor forward. It returns nil when there is no result set or
// every field has already been handed out.
func (s *Statement) FetchColumn() *Field {
	if !s.checkResult() {
		return nil
	}
	res := s.result
	// All fields fetched already
	if res.fieldPos >= uint64(len(res.fields)) {
		return nil
	}
	// Hand out the current field, then advance
	field := res.fields[res.fieldPos]
	res.fieldPos++
	return field
}
// FetchColumns returns all fields of the current result set, or nil when
// the statement has no result set.
func (s *Statement) FetchColumns() []*Field {
	if !s.checkResult() {
		return nil
	}
	return s.result.fields
}
// BindResult registers the destinations that Fetch fills from each row.
// The values are stored as given, without validation, and the returned
// error is always nil.
func (s *Statement) BindResult(params ...interface{}) (err os.Error) {
	s.resultParams = params
	return nil
}
// RowCount returns the number of rows held by a stored result set. It
// returns 0 when there is no result set or the result is not in stored
// mode.
func (s *Statement) RowCount() uint64 {
	if !s.checkResult() || s.result.mode != RESULT_STORED {
		return 0
	}
	return uint64(len(s.result.rows))
}
// Fetch advances to the next row of the current result set and copies its
// values into the destinations registered with BindResult.
//
// It returns eof == true once no further row is available. The statement
// must be prepared (CR_NO_PREPARE_STMT) and have a result set
// (CR_NO_RESULT_SET). For unused/used results the row is read from the
// server, for stored results it is taken from the stored rows. A panic
// raised while assigning values is reported as CR_UNKNOWN_ERROR.
// Destinations of a type not listed below are left untouched.
func (s *Statement) Fetch() (eof bool, err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log fetch
	s.c.log(1, "=== Begin fetch ===")
	// Check prepared
	if !s.prepared {
		return false, &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check result
	if !s.checkResult() {
		return false, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR}
	}
	// Pick the row according to the result mode
	var current Row
	switch s.result.mode {
	// Stored result, rows are already available locally
	case RESULT_STORED:
		if s.result.rowPos >= uint64(len(s.result.rows)) {
			return true, nil
		}
		current = s.result.rows[s.result.rowPos]
		s.result.rowPos++
	// Used or unused result (needs fetching)
	case RESULT_UNUSED, RESULT_USED:
		s.result.mode = RESULT_USED
		if s.result.allRead {
			return true, nil
		}
		done, rowErr := s.getRow()
		if rowErr != nil {
			return false, rowErr
		}
		if done {
			s.result.allRead = true
			return true, nil
		}
		current = s.result.rows[0]
	}
	// From here on, a panic raised by a failed assertion or an out of
	// range row index is turned into an error
	defer func() {
		if recovered := recover(); recovered != nil {
			err = &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR}
		}
	}()
	// Walk the bound destinations and assign the matching row value
	for i, target := range s.resultParams {
		switch dst := target.(type) {
		// Integer types
		case *int:
			*dst = int(atoui64(current[i]))
		case *uint:
			*dst = uint(atoui64(current[i]))
		case *int8:
			*dst = int8(atoui64(current[i]))
		case *uint8:
			*dst = uint8(atoui64(current[i]))
		case *int16:
			*dst = int16(atoui64(current[i]))
		case *uint16:
			*dst = uint16(atoui64(current[i]))
		case *int32:
			*dst = int32(atoui64(current[i]))
		case *uint32:
			*dst = uint32(atoui64(current[i]))
		case *int64:
			*dst = int64(atoui64(current[i]))
		case *uint64:
			*dst = atoui64(current[i])
		// Floating point types
		case *float32:
			*dst = float32(atof64(current[i]))
		case *float64:
			*dst = atof64(current[i])
		// Byte slice, plain assertion
		case *[]byte:
			*dst = current[i].([]byte)
		// Strings
		case *string:
			*dst = atos(current[i])
		// Date/time, plain assertion
		case *Date:
			*dst = current[i].(Date)
		case *Time:
			*dst = current[i].(Time)
		case *DateTime:
			*dst = current[i].(DateTime)
		}
	}
	return false, nil
}
// StoreResult reads every remaining row of the current result set from
// the server and switches the result to stored mode.
//
// The statement must be prepared (CR_NO_PREPARE_STMT), must have a result
// set (CR_NO_RESULT_SET) and that result must not have been used or
// stored before (CR_COMMANDS_OUT_OF_SYNC). The returned error is passed
// through the client's simpleReconnect hook.
func (s *Statement) StoreResult() (err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log store result
	s.c.log(1, "=== Begin store result ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check result, s.result is dereferenced below and is nil when the
	// last command produced no result set
	if !s.checkResult() {
		return &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR}
	}
	// Check if result already used/stored
	if s.result.mode != RESULT_UNUSED {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Set storage mode
	s.result.mode = RESULT_STORED
	// Store all rows
	err = s.getAllRows()
	if err != nil {
		return
	}
	s.result.allRead = true
	return
}
// FreeResult discards the current result set, reading any rows that are
// still unread so the connection can be used for the next command.
//
// The statement must be prepared (CR_NO_PREPARE_STMT) and have a result
// set (CR_NO_RESULT_SET). A failure while draining the rows is returned
// to the caller; the error is passed through the client's simpleReconnect
// hook.
func (s *Statement) FreeResult() (err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log free result
	s.c.log(1, "=== Begin free result ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Check result
	if !s.checkResult() {
		return &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR}
	}
	// Free the current result set and report a failure instead of
	// dropping it
	err = s.freeAll(false)
	return
}
// MoreResults reports whether the client has further results pending. It
// simply forwards to the client's MoreResults.
func (s *Statement) MoreResults() bool {
	pending := s.c.MoreResults()
	return pending
}
// NextResult moves on to the next result of the statement, if any.
//
// more reports whether another result was pending. The statement must be
// prepared (CR_NO_PREPARE_STMT), the connection check must pass and the
// previous result must already have been freed (CR_COMMANDS_OUT_OF_SYNC).
// When the next result is a result set, its field packets are read too.
// The returned error is passed through the client's simpleReconnect hook.
func (s *Statement) NextResult() (more bool, err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log next result
	s.c.log(1, "=== Begin next result ===")
	// Check prepared
	if !s.prepared {
		return false, &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Pre-run checks, the connection is checked first
	if connected := s.c.checkConn(); !connected || s.checkResult() {
		return false, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Nothing to do without further results
	if more = s.MoreResults(); !more {
		return
	}
	// Read result from server
	s.c.sequence++
	if _, err = s.getResult(PACKET_OK | PACKET_ERROR | PACKET_RESULT); err != nil || s.result == nil {
		return
	}
	// Store fields
	err = s.getFields()
	return
}
// Reset sends COM_STMT_RESET for the statement after draining any result
// sets that are still pending.
//
// The statement must be prepared (CR_NO_PREPARE_STMT) and the connection
// check must pass (CR_COMMANDS_OUT_OF_SYNC). A failure while draining
// pending results aborts the reset and is returned. The returned error is
// passed through the client's simpleReconnect hook.
func (s *Statement) Reset() (err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log reset statement
	s.c.log(1, "=== Begin reset statement ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Pre-run checks
	if !s.c.checkConn() {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Free any results; stop here on failure, otherwise the error would be
	// overwritten by the command below
	if s.checkResult() {
		err = s.freeAll(true)
		if err != nil {
			return
		}
	}
	// Reset client
	s.reset()
	// Send command
	err = s.c.command(COM_STMT_RESET, s.statementId)
	if err != nil {
		return
	}
	// Read result from server
	s.c.sequence++
	_, err = s.getResult(PACKET_OK | PACKET_ERROR)
	return
}
// Close releases the prepared statement on the server with
// COM_STMT_CLOSE.
//
// The statement must be prepared (CR_NO_PREPARE_STMT), the connection
// check must pass and no result may be pending (CR_COMMANDS_OUT_OF_SYNC).
// No reply is read because the protocol defines none for COM_STMT_CLOSE.
// After a successful close the statement is no longer flagged as prepared
// and has to be prepared again before further use. The returned error is
// passed through the client's simpleReconnect hook.
func (s *Statement) Close() (err os.Error) {
	// Auto reconnect
	defer func() {
		err = s.c.simpleReconnect(err)
	}()
	// Log close statement
	s.c.log(1, "=== Begin close statement ===")
	// Check prepared
	if !s.prepared {
		return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
	}
	// Pre-run checks
	if !s.c.checkConn() || s.checkResult() {
		return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
	}
	// Reset client
	s.reset()
	// Send the close command; a reset command would keep the statement
	// allocated on the server and leave its reply unread
	err = s.c.command(COM_STMT_CLOSE, s.statementId)
	if err != nil {
		return
	}
	// The statement id is no longer valid on the server
	s.prepared = false
	return
}
// reset clears the outcome of the previous command (affected rows, last
// insert id, warnings and the result set) and then resets the client.
func (s *Statement) reset() {
	s.AffectedRows, s.LastInsertId, s.Warnings = 0, 0, 0
	s.result = nil
	s.c.reset()
}
// checkResult reports whether the statement currently holds a result set.
func (s *Statement) checkResult() bool {
	return s.result != nil
}
// getNullBitMap builds the NULL bitmap for the execute packet: one bit per
// parameter, least significant bit first within each byte, set when the
// parameter is bound as FIELD_TYPE_NULL.
//
// The returned slice holds (paramCount+7)/8 bytes. Bits are written
// straight into the byte slice, so statements with more than 64
// parameters are handled correctly. Parameters without a type entry are
// treated as not NULL.
func (s *Statement) getNullBitMap() (nbm []byte) {
	// Size is computed in int so paramCount+7 cannot wrap around in uint16
	nbm = make([]byte, (int(s.paramCount)+7)/8)
	// Check if params are null (nil)
	for i := uint16(0); i < s.paramCount && int(i) < len(s.paramType); i++ {
		if s.paramType[i][0] == byte(FIELD_TYPE_NULL) {
			nbm[i/8] |= byte(1) << (i % 8)
		}
	}
	return
}
// getFields reads field packets from the server until the EOF packet that
// ends the field list. The sequence number is advanced before every read.
// The first read error stops the loop and is returned.
func (s *Statement) getFields() (err os.Error) {
	// eof is declared outside the loop so the assignment below writes to
	// the named result err instead of shadowing it; with a shadowed err
	// the bare return would lose the failure.
	var eof bool
	// Loop till EOF
	for !eof {
		s.c.sequence++
		eof, err = s.getResult(PACKET_FIELD | PACKET_EOF)
		if err != nil {
			return
		}
	}
	return
}
// getRow reads the next binary row packet, or the EOF packet that ends the
// rows, for the current result set. eof is true when the EOF packet was
// read. Without a result set it returns CR_NO_RESULT_SET and reads
// nothing.
func (s *Statement) getRow() (eof bool, err os.Error) {
	if s.result != nil {
		// Read next row packet or EOF
		s.c.sequence++
		return s.getResult(PACKET_ROW_BINARY | PACKET_EOF)
	}
	return false, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR}
}
// getAllRows reads rows for the current result set until the EOF packet
// is reached. The first read error stops the loop and is returned.
func (s *Statement) getAllRows() (err os.Error) {
	// eof is declared outside the loop so the assignment below writes to
	// the named result err instead of shadowing it; with a shadowed err
	// the bare return would lose the failure.
	var eof bool
	for !eof {
		eof, err = s.getRow()
		if err != nil {
			return
		}
	}
	return
}
// getResult reads one packet of the given types from the server and
// passes it to the handler for its concrete type.
//
// eof is true only when an EOF packet was read. A result set packet
// replaces s.result with a fresh Result before it is handled. A packet of
// any other type than the ones listed yields CR_UNKNOWN_ERROR.
func (s *Statement) getResult(types packetType) (eof bool, err os.Error) {
	// Log read result
	s.c.log(1, "Reading result packet from server")
	// Get result packet
	packet, readErr := s.c.r.readPacket(types)
	if readErr != nil {
		return false, readErr
	}
	// Process result packet
	switch pk := packet.(type) {
	case *packetOK:
		err = handleOK(pk, s.c, &s.AffectedRows, &s.LastInsertId, &s.Warnings)
	case *packetError:
		err = handleError(pk, s.c)
	case *packetEOF:
		eof = true
		err = handleEOF(pk, s.c)
	case *packetPrepareOK:
		err = handlePrepareOK(pk, s.c, s)
	case *packetParameter:
		err = handleParam(pk, s.c)
	case *packetField:
		err = handleField(pk, s.c, s.result)
	case *packetResultSet:
		s.result = &Result{c: s.c}
		err = handleResultSet(pk, s.c, s.result)
	case *packetRowBinary:
		err = handleBinaryRow(pk, s.c, s.result)
	default:
		err = &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR}
	}
	return eof, err
}
// freeAll drains the current result set and drops it. When next is true
// it keeps going while the client reports more results: each further
// result is read, marked RESULT_FREE and its fields and rows are consumed.
// The first read error stops the work and is returned. The caller must
// ensure s.result is not nil.
func (s *Statement) freeAll(next bool) (err os.Error) {
	// Read whatever rows are still unread
	if !s.result.allRead {
		if err = s.getAllRows(); err != nil {
			return
		}
	}
	// Unset the result
	s.result = nil
	// Only continue when following results should be freed too
	if !next {
		return
	}
	// Consume following results for as long as more exist
	for s.c.MoreResults() {
		// Get next result
		s.c.sequence++
		if _, err = s.getResult(PACKET_OK | PACKET_ERROR | PACKET_RESULT); err != nil {
			return
		}
		// No result set this time, nothing to drain
		if s.result == nil {
			continue
		}
		// Set result mode to RESULT_FREE
		s.result.mode = RESULT_FREE
		// Read fields
		if err = s.getFields(); err != nil {
			return
		}
		// Read rows
		if err = s.getAllRows(); err != nil {
			return
		}
	}
	return
}
Jump to Line
Something went wrong with that request. Please try again.