Skip to content

Commit

Permalink
merge with thomaslee 838eced
Browse files Browse the repository at this point in the history
  • Loading branch information
Phil Bayfield committed Dec 12, 2010
1 parent 96972c4 commit c431804
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 71 deletions.
9 changes: 3 additions & 6 deletions mysql.go
Expand Up @@ -12,7 +12,6 @@ import (
"bufio"
"os"
"log"
"reflect"
"sync"
)

Expand Down Expand Up @@ -403,11 +402,9 @@ func (mysql *MySQL) parseParams(p []interface{}) {
}
// Reflect 5th param to determine if it is port or socket
if len(p) > 4 {
v := reflect.NewValue(p[4])
if v.Type().Name() == "int" {
mysql.auth.port = v.Interface().(int)
} else if v.Type().Name() == "string" {
mysql.auth.socket = v.Interface().(string)
switch v := p[4].(type) {
case int: mysql.auth.port = v
case string: mysql.auth.socket = v
}
}
return
Expand Down
112 changes: 47 additions & 65 deletions mysql_packet.go
Expand Up @@ -10,7 +10,6 @@ import (
"bufio"
"os"
"crypto/sha1"
"reflect"
"strconv"
"math"
)
Expand Down Expand Up @@ -395,22 +394,15 @@ func (pkt *packetCommand) write(writer *bufio.Writer) (err os.Error) {
pkt.header = new(packetHeader)
pkt.header.length = 1
// Calculate packet length
var v reflect.Value
for i := 0; i < len(pkt.args); i++ {
v = reflect.NewValue(pkt.args[i])
switch v.Type().Name() {
switch v := pkt.args[i].(type) {
default:
return os.ErrorString("Unsupported type")
case "string":
pkt.header.length += uint32(len(v.Interface().(string)))
case "uint8":
pkt.header.length += 1
case "uint16":
pkt.header.length += 2
case "uint32":
pkt.header.length += 4
case "uint64":
pkt.header.length += 8
case string: pkt.header.length += uint32(len(v))
case uint8: pkt.header.length += 1
case uint16: pkt.header.length += 2
case uint32: pkt.header.length += 4
case uint64: pkt.header.length += 8
}
}
pkt.header.sequence = 0
Expand All @@ -425,18 +417,12 @@ func (pkt *packetCommand) write(writer *bufio.Writer) (err os.Error) {
}
// Write params
for i := 0; i < len(pkt.args); i++ {
v = reflect.NewValue(pkt.args[i])
switch v.Type().Name() {
case "string":
_, err = writer.WriteString(v.Interface().(string))
case "uint8":
err = pkt.writeNumber(writer, uint64(v.Interface().(uint8)), 1)
case "uint16":
err = pkt.writeNumber(writer, uint64(v.Interface().(uint16)), 2)
case "uint32":
err = pkt.writeNumber(writer, uint64(v.Interface().(uint32)), 4)
case "uint64":
err = pkt.writeNumber(writer, uint64(v.Interface().(uint64)), 8)
switch v := pkt.args[i].(type) {
case string: _, err = writer.WriteString(v)
case uint8: err = pkt.writeNumber(writer, uint64(v), 1)
case uint16: err = pkt.writeNumber(writer, uint64(v), 2)
case uint32: err = pkt.writeNumber(writer, uint64(v), 4)
case uint64: err = pkt.writeNumber(writer, v, 8)
}
if err != nil {
return
Expand Down Expand Up @@ -820,12 +806,10 @@ type packetExecute struct {
*/
func (pkt *packetExecute) encodeNullBits(params []interface{}) {
pkt.nullBitMap = make([]byte, (len(params)+7)/8)
var v reflect.Value
var bitMap uint64 = 0
// Check if params are null (nil)
for i := 0; i < len(params); i++ {
v = reflect.NewValue(params[i])
if reflect.Indirect(v) == nil {
if params[i] == nil {
bitMap += 1 << uint(i)
}
}
Expand All @@ -843,97 +827,95 @@ func (pkt *packetExecute) encodeParams(params []interface{}) {
pkt.paramData = make([][]byte, len(params))
// Add all param types
for i := 0; i < len(params); i++ {
var v reflect.Value
var n uint16
// Reflect param
v = reflect.NewValue(params[i])
v := params[i]
// Check for nils (NULL)
if reflect.Indirect(v) == nil {
if v == nil {
n = uint16(FIELD_TYPE_NULL)
} else {
// Match go types to MySQL types
switch v.Type().Name() {
switch value := v.(type) {
// Strings should be length coded binary
case "string":
case string:
n = uint16(FIELD_TYPE_STRING)
bytes, length := pkt.packString(v.Interface().(string))
bytes, length := pkt.packString(value)
pkt.paramData[i] = bytes
pkt.paramLength += uint32(length)
// Unsigned ints simple binary encoded
case "uint":
case uint:
// uint can be 32 or 64 bits
if strconv.IntSize == 32 {
n = uint16(FIELD_TYPE_LONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint)), 4)
pkt.paramData[i] = pkt.packNumber(uint64(value), 4)
pkt.paramLength += 4
} else {
n = uint16(FIELD_TYPE_LONGLONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint)), 8)
pkt.paramData[i] = pkt.packNumber(uint64(value), 8)
pkt.paramLength += 8
}
case "uint8":
case uint8:
n = uint16(FIELD_TYPE_TINY)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint8)), 1)
pkt.paramData[i] = pkt.packNumber(uint64(value), 1)
pkt.paramLength++
case "uint16":
case uint16:
n = uint16(FIELD_TYPE_SHORT)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint16)), 2)
pkt.paramData[i] = pkt.packNumber(uint64(value), 2)
pkt.paramLength += 2
case "uint32":
case uint32:
n = uint16(FIELD_TYPE_LONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint32)), 4)
pkt.paramData[i] = pkt.packNumber(uint64(value), 4)
pkt.paramLength += 4
case "uint64":
case uint64:
n = uint16(FIELD_TYPE_LONGLONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(uint64)), 8)
pkt.paramData[i] = pkt.packNumber(uint64(value), 8)
pkt.paramLength += 8
// Signed ints also encoded as uint as server 'should' determine their sign based on field type
case "int":
case int:
// int can be 32 or 64 bits
if strconv.IntSize == 32 {
n = uint16(FIELD_TYPE_LONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int)), 4)
pkt.paramData[i] = pkt.packNumber(uint64(value), 4)
pkt.paramLength += 4
} else {
n = uint16(FIELD_TYPE_LONGLONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int)), 8)
pkt.paramData[i] = pkt.packNumber(uint64(value), 8)
pkt.paramLength += 8
}
case "int8":
case int8:
n = uint16(FIELD_TYPE_TINY)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int8)), 1)
pkt.paramData[i] = pkt.packNumber(uint64(value), 1)
pkt.paramLength++
case "int16":
case int16:
n = uint16(FIELD_TYPE_SHORT)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int16)), 2)
pkt.paramData[i] = pkt.packNumber(uint64(value), 2)
pkt.paramLength += 2
case "int32":
case int32:
n = uint16(FIELD_TYPE_LONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int32)), 4)
pkt.paramData[i] = pkt.packNumber(uint64(value), 4)
pkt.paramLength += 4
case "int64":
case int64:
n = uint16(FIELD_TYPE_LONGLONG)
pkt.paramData[i] = pkt.packNumber(uint64(v.Interface().(int64)), 8)
pkt.paramData[i] = pkt.packNumber(uint64(value), 8)
pkt.paramLength += 8
// Floats
case "float":
case float:
if strconv.FloatSize == 32 {
n = uint16(FIELD_TYPE_FLOAT)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float32bits(float32(v.Interface().(float)))), 4)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float32bits(float32(value))), 4)
pkt.paramLength += 4
} else {
n = uint16(FIELD_TYPE_DOUBLE)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float64bits(float64(v.Interface().(float)))), 8)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float64bits(float64(value))), 8)
pkt.paramLength += 8
}

case "float32":
case float32:
n = uint16(FIELD_TYPE_FLOAT)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float32bits(float32(v.Interface().(float32)))), 4)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float32bits(float32(value))), 4)
pkt.paramLength += 4
case "float64":
case float64:
n = uint16(FIELD_TYPE_DOUBLE)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float64bits(float64(v.Interface().(float64)))), 8)
pkt.paramData[i] = pkt.packNumber(uint64(math.Float64bits(float64(value))), 8)
pkt.paramLength += 8
}
}
Expand Down

0 comments on commit c431804

Please sign in to comment.