Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP

Loading…

Couple little things #54

Open
wants to merge 6 commits into from

3 participants

This page is out of date. Refresh to see the latest.
Showing with 248 additions and 102 deletions.
  1. +5 −0 .gitignore
  2. +178 −73 mysql_test.go
  3. +17 −2 result.go
  4. +48 −27 statement.go
View
5 .gitignore
@@ -0,0 +1,5 @@
+*~
+*.out
+_test*
+_obj*
+_go*
View
251 mysql_test.go
@@ -10,18 +10,21 @@ import (
"os"
"rand"
"strconv"
+ "sync"
"testing"
)
-const (
- // Testing credentials, run the following on server client prior to running:
- // create database gomysql_test;
- // create database gomysql_test2;
- // create database gomysql_test3;
- // create user gomysql_test@localhost identified by 'abc123';
- // grant all privileges on gomysql_test.* to gomysql_test@localhost;
- // grant all privileges on gomysql_test2.* to gomysql_test@localhost;
+const instructions = `To run the GoMySQL tests, run the following on the server first:
+
+ create database gomysql_test;
+ create database gomysql_test2;
+ create database gomysql_test3;
+ create user gomysql_test@localhost identified by 'abc123';
+ grant all privileges on gomysql_test.* to gomysql_test@localhost;
+ grant all privileges on gomysql_test2.* to gomysql_test@localhost;
+`
+const (
// Testing settings
TEST_HOST = "localhost"
TEST_PORT = "3306"
@@ -42,6 +45,7 @@ const (
UPDATE_SIMPLE = "UPDATE simple SET `text` = '%s', `datetime` = NOW() WHERE id = %d"
UPDATE_SIMPLE_STMT = "UPDATE simple SET `text` = ?, `datetime` = NOW() WHERE id = ?"
DROP_SIMPLE = "DROP TABLE `simple`"
+ DROP_SIMPLE_MAYBE = "DROP TABLE IF EXISTS `simple`"
// All types table queries
CREATE_ALLTYPES = "CREATE TABLE `all_types` (`id` SERIAL NOT NULL, `tiny_int` TINYINT NOT NULL, `tiny_uint` TINYINT UNSIGNED NOT NULL, `small_int` SMALLINT NOT NULL, `small_uint` SMALLINT UNSIGNED NOT NULL, `medium_int` MEDIUMINT NOT NULL, `medium_uint` MEDIUMINT UNSIGNED NOT NULL, `int` INT NOT NULL, `uint` INT UNSIGNED NOT NULL, `big_int` BIGINT NOT NULL, `big_uint` BIGINT UNSIGNED NOT NULL, `decimal` DECIMAL(10,4) NOT NULL, `float` FLOAT NOT NULL, `double` DOUBLE NOT NULL, `real` REAL NOT NULL, `bit` BIT(32) NOT NULL, `boolean` BOOLEAN NOT NULL, `date` DATE NOT NULL, `datetime` DATETIME NOT NULL, `timestamp` TIMESTAMP NOT NULL, `time` TIME NOT NULL, `year` YEAR NOT NULL, `char` CHAR(32) NOT NULL, `varchar` VARCHAR(32) NOT NULL, `tiny_text` TINYTEXT NOT NULL, `text` TEXT NOT NULL, `medium_text` MEDIUMTEXT NOT NULL, `long_text` LONGTEXT NOT NULL, `binary` BINARY(32) NOT NULL, `var_binary` VARBINARY(32) NOT NULL, `tiny_blob` TINYBLOB NOT NULL, `medium_blob` MEDIUMBLOB NOT NULL, `blob` BLOB NOT NULL, `long_blob` LONGBLOB NOT NULL, `enum` ENUM('a','b','c','d','e') NOT NULL, `set` SET('a','b','c','d','e') NOT NULL, `geometry` GEOMETRY NOT NULL) ENGINE = InnoDB CHARACTER SET utf8 COLLATE utf8_unicode_ci COMMENT = 'GoMySQL Test Suite All Types Table'"
@@ -49,8 +53,10 @@ const (
)
var (
- db *Client
- err os.Error
+ db *Client
+ err os.Error
+ checkOnce sync.Once
+ skipTests bool
)
type SimpleRow struct {
@@ -61,8 +67,40 @@ type SimpleRow struct {
Date string
}
+func verifyConnections() {
+ db, err = DialTCP(TEST_HOST, TEST_USER, TEST_PASSWD, TEST_DBNAME)
+ if db != nil {
+ db.Close()
+ }
+ if err != nil {
+ skipTests = true
+ os.Stderr.Write([]byte(instructions))
+ return
+ }
+ db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
+ if db != nil {
+ db.Close()
+ }
+ if err != nil {
+ skipTests = true
+ os.Stderr.Write([]byte(instructions))
+ return
+ }
+}
+
+func skipTest(t *testing.T) bool {
+ checkOnce.Do(verifyConnections)
+ if skipTests {
+ t.Logf("skipping test; see instructions")
+ }
+ return skipTests
+}
+
// Test connect to server via TCP
func TestDialTCP(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running DialTCP test to %s:%s", TEST_HOST, TEST_PORT)
db, err = DialTCP(TEST_HOST, TEST_USER, TEST_PASSWD, TEST_DBNAME)
if err != nil {
@@ -78,6 +116,9 @@ func TestDialTCP(t *testing.T) {
// Test connect to server via Unix socket
func TestDialUnix(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running DialUnix test to %s", TEST_SOCK)
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
if err != nil {
@@ -93,6 +134,9 @@ func TestDialUnix(t *testing.T) {
// Test connect to server with unprivileged database
func TestDialUnixUnpriv(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running DialUnix test to unprivileged database %s", TEST_DBNAMEUP)
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAMEUP)
if err != nil {
@@ -108,6 +152,9 @@ func TestDialUnixUnpriv(t *testing.T) {
// Test connect to server with nonexistant database
func TestDialUnixNonex(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running DialUnix test to nonexistant database %s", TEST_DBNAMEBAD)
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAMEBAD)
if err != nil {
@@ -123,6 +170,9 @@ func TestDialUnixNonex(t *testing.T) {
// Test connect with bad password
func TestDialUnixBadPass(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running DialUnix test with bad password")
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_BAD_PASSWD, TEST_DBNAME)
if err != nil {
@@ -138,20 +188,24 @@ func TestDialUnixBadPass(t *testing.T) {
// Test queries on a simple table (create database, select, insert, update, drop database)
func TestSimple(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
t.Logf("Running simple table tests")
db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Create table")
+ db.Query(DROP_SIMPLE_MAYBE)
err = db.Query(CREATE_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Insert 1000 records")
rowMap := make(map[uint64][]string)
for i := 0; i < 1000; i++ {
@@ -164,21 +218,21 @@ func TestSimple(t *testing.T) {
row := []string{fmt.Sprintf("%d", num), str1, str2}
rowMap[db.LastInsertId] = row
}
-
+
t.Logf("Select inserted data")
err = db.Query(SELECT_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Use result")
res, err := db.UseResult()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Validate inserted data")
for {
row := res.FetchRow()
@@ -187,19 +241,23 @@ func TestSimple(t *testing.T) {
}
id := row[0].(uint64)
num, str1, str2 := strconv.Itoa64(row[1].(int64)), row[2].(string), string(row[3].([]byte))
- if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 {
+ expectRow, ok := rowMap[id]
+ if !ok {
+ t.Fatalf("read unexpected row number %d", id)
+ }
+ if expectRow[0] != num || expectRow[1] != str1 || expectRow[2] != str2 {
t.Logf("String from database doesn't match local string")
t.Fail()
}
}
-
+
t.Logf("Free result")
err = res.Free()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Update some records")
for i := uint64(0); i < 1000; i += 5 {
rowMap[i+1][2] = randString(256)
@@ -213,21 +271,21 @@ func TestSimple(t *testing.T) {
t.Fail()
}
}
-
+
t.Logf("Select updated data")
err = db.Query(SELECT_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Store result")
res, err = db.StoreResult()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Validate updated data")
for {
row := res.FetchRow()
@@ -242,7 +300,7 @@ func TestSimple(t *testing.T) {
t.Fail()
}
}
-
+
t.Logf("Free result")
err = res.Free()
if err != nil {
@@ -256,7 +314,7 @@ func TestSimple(t *testing.T) {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Close connection")
err = db.Close()
if err != nil {
@@ -265,79 +323,79 @@ func TestSimple(t *testing.T) {
}
}
-// Test queries on a simple table (create database, select, insert, update, drop database) using a statement
-func TestSimpleStatement(t *testing.T) {
- t.Logf("Running simple table statement tests")
- db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
- if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
- }
-
- t.Logf("Init statement")
+func insert1000Records(t *testing.T, db *Client) map[uint64][]string {
stmt, err := db.InitStmt()
if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
+ t.Fatalf("InitStmt: %v", err)
}
-
- t.Logf("Prepare create table")
- err = stmt.Prepare(CREATE_SIMPLE)
- if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
- }
-
- t.Logf("Execute create table")
- err = stmt.Execute()
- if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
- }
-
- t.Logf("Prepare insert")
+
err = stmt.Prepare(INSERT_SIMPLE_STMT)
if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
+ t.Logf("Prepare insert: %v", err)
}
-
+
t.Logf("Insert 1000 records")
rowMap := make(map[uint64][]string)
for i := 0; i < 1000; i++ {
num, str1, str2 := rand.Int(), randString(32), randString(128)
err = stmt.BindParams(num, str1, str2)
if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
+ t.Fatalf("Error %s", err)
}
err = stmt.Execute()
if err != nil {
- t.Logf("Error %s", err)
- t.Fail()
+ t.Fatalf("Error %s", err)
}
row := []string{fmt.Sprintf("%d", num), str1, str2}
rowMap[stmt.LastInsertId] = row
}
-
+ return rowMap
+}
+
+// Test queries on a simple table (create database, select, insert, update, drop database) using a statement
+func TestSimpleStatement(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ t.Logf("Running simple table statement tests")
+ db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
+ if err != nil {
+ t.Logf("Error %s", err)
+ t.Fail()
+ }
+
+ db.Query(DROP_SIMPLE_MAYBE)
+ err := db.Query(CREATE_SIMPLE)
+ if err != nil {
+ t.Fatalf("create table: %v", err)
+ }
+ defer db.Query(DROP_SIMPLE)
+
+ rowMap := insert1000Records(t, db)
+
+ stmt, err := db.InitStmt()
+ if err != nil {
+ t.Fatalf("InitStmt: %v", err)
+ }
+
t.Logf("Prepare select")
err = stmt.Prepare(SELECT_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Execute select")
err = stmt.Execute()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Bind result")
row := SimpleRow{}
stmt.BindResult(&row.Id, &row.Number, &row.String, &row.Text, &row.Date)
-
+
t.Logf("Validate inserted data")
for {
eof, err := stmt.Fetch()
@@ -353,21 +411,21 @@ func TestSimpleStatement(t *testing.T) {
t.Fail()
}
}
-
+
t.Logf("Reset statement")
err = stmt.Reset()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Prepare update")
err = stmt.Prepare(UPDATE_SIMPLE_STMT)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Update some records")
for i := uint64(0); i < 1000; i += 5 {
rowMap[i+1][2] = randString(256)
@@ -382,21 +440,21 @@ func TestSimpleStatement(t *testing.T) {
t.Fail()
}
}
-
+
t.Logf("Prepare select updated")
err = stmt.Prepare(SELECT_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Execute select updated")
err = stmt.Execute()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Validate updated data")
for {
eof, err := stmt.Fetch()
@@ -412,35 +470,35 @@ func TestSimpleStatement(t *testing.T) {
t.Fail()
}
}
-
+
t.Logf("Free result")
err = stmt.FreeResult()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Prepare drop")
err = stmt.Prepare(DROP_SIMPLE)
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Execute drop")
err = stmt.Execute()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Close statement")
err = stmt.Close()
if err != nil {
t.Logf("Error %s", err)
t.Fail()
}
-
+
t.Logf("Close connection")
err = db.Close()
if err != nil {
@@ -449,6 +507,53 @@ func TestSimpleStatement(t *testing.T) {
}
}
+func TestStatementUseResult(t *testing.T) {
+ if skipTest(t) {
+ return
+ }
+ db, err := DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME)
+ if err != nil {
+ t.Fatalf("dial error: %v", err)
+ }
+ defer db.Close()
+
+ db.Query(DROP_SIMPLE_MAYBE)
+ err = db.Query(CREATE_SIMPLE)
+ if err != nil {
+ t.Fatalf("create table: %v", err)
+ }
+ defer db.Query(DROP_SIMPLE)
+
+ insert1000Records(t, db)
+ stmt, err := db.Prepare(SELECT_SIMPLE)
+ if err != nil {
+ t.Fatalf("Prepare select: %v", err)
+ }
+ err = stmt.Execute()
+ if err != nil {
+ t.Fatalf("Execute: %v", err)
+ }
+ res, err := stmt.UseResult()
+ if err != nil {
+ t.Fatalf("UseResult: %v", err)
+ }
+ nRows := 0
+ for {
+ row := res.FetchRow()
+ if row == nil {
+ break
+ }
+ nRows++
+ }
+ if nRows != 1000 {
+ t.Errorf("expected 1000 rows; got %d", nRows)
+ }
+ err = res.Free()
+ if err != nil {
+ t.Logf("Free result: %s", err)
+ }
+}
+
// Benchmark connect/handshake via TCP
func BenchmarkDialTCP(b *testing.B) {
for i := 0; i < b.N; i++ {
View
19 result.go
@@ -12,6 +12,10 @@ type Result struct {
// Pointer to the client
c *Client
+ // if non-nil, the Result came from a Statement, not
+ // via Client.Query.
+ s *Statement
+
// Fields
fieldCount uint64
fieldPos uint64
@@ -71,6 +75,13 @@ func (r *Result) RowCount() uint64 {
return 0
}
+func (r *Result) getRow() (eof bool, err os.Error) {
+ if r.s != nil {
+ return r.s.getRow()
+ }
+ return r.c.getRow()
+}
+
// Fetch a row
func (r *Result) FetchRow() Row {
// Stored result
@@ -85,7 +96,7 @@ func (r *Result) FetchRow() Row {
// Used result
if r.mode == RESULT_USED {
if r.allRead == false {
- eof, err := r.c.getRow()
+ eof, err := r.getRow()
if err != nil {
return nil
}
@@ -123,6 +134,10 @@ func (r *Result) FetchRows() []Row {
// Free the result
func (r *Result) Free() (err os.Error) {
- err = r.c.FreeResult()
+ if r.s != nil {
+ err = r.s.FreeResult()
+ } else {
+ err = r.c.FreeResult()
+ }
return
}
View
75 statement.go
@@ -123,7 +123,7 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) {
var t FieldType
var d []byte
// Switch on type
- switch param.(type) {
+ switch p := param.(type) {
// Nil
case nil:
t = FIELD_TYPE_NULL
@@ -134,7 +134,7 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) {
} else {
t = FIELD_TYPE_LONGLONG
}
- d = itob(param.(int))
+ d = itob(p)
// Uint
case uint:
if strconv.IntSize == 32 {
@@ -142,57 +142,57 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) {
} else {
t = FIELD_TYPE_LONGLONG
}
- d = uitob(param.(uint))
+ d = uitob(p)
// Int8
case int8:
t = FIELD_TYPE_TINY
- d = []byte{byte(param.(int8))}
+ d = []byte{byte(p)}
// Uint8
case uint8:
t = FIELD_TYPE_TINY
- d = []byte{param.(uint8)}
+ d = []byte{p}
// Int16
case int16:
t = FIELD_TYPE_SHORT
- d = i16tob(param.(int16))
+ d = i16tob(p)
// Uint16
case uint16:
t = FIELD_TYPE_SHORT
- d = ui16tob(param.(uint16))
+ d = ui16tob(p)
// Int32
case int32:
t = FIELD_TYPE_LONG
- d = i32tob(param.(int32))
+ d = i32tob(p)
// Uint32
case uint32:
t = FIELD_TYPE_LONG
- d = ui32tob(param.(uint32))
+ d = ui32tob(p)
// Int64
case int64:
t = FIELD_TYPE_LONGLONG
- d = i64tob(param.(int64))
+ d = i64tob(p)
// Uint64
case uint64:
t = FIELD_TYPE_LONGLONG
- d = ui64tob(param.(uint64))
+ d = ui64tob(p)
// Float32
case float32:
t = FIELD_TYPE_FLOAT
- d = f32tob(param.(float32))
+ d = f32tob(p)
// Float64
case float64:
t = FIELD_TYPE_DOUBLE
- d = f64tob(param.(float64))
+ d = f64tob(p)
// String
case string:
t = FIELD_TYPE_STRING
- d = lcbtob(uint64(len(param.(string))))
- d = append(d, []byte(param.(string))...)
+ d = lcbtob(uint64(len(p)))
+ d = append(d, []byte(p)...)
// Byte array
case []byte:
t = FIELD_TYPE_BLOB
- d = lcbtob(uint64(len(param.([]byte))))
- d = append(d, param.([]byte)...)
+ d = lcbtob(uint64(len(p)))
+ d = append(d, p...)
// Other types
default:
return &ClientError{CR_UNSUPPORTED_PARAM_TYPE, s.c.fmtError(CR_UNSUPPORTED_PARAM_TYPE_STR, reflect.ValueOf(param).Type(), k)}
@@ -472,6 +472,27 @@ func (s *Statement) Fetch() (eof bool, err os.Error) {
return
}
+// Use result
+func (s *Statement) UseResult() (*Result, os.Error) {
+ // Log use result
+ s.c.log(1, "=== Begin use result ===")
+ // Check prepared
+ if !s.prepared {
+ return nil, &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR}
+ }
+ // Check result
+ if !s.checkResult() {
+ return nil, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR}
+ }
+ // Check if result already used/stored
+ if s.result.mode != RESULT_UNUSED {
+ return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}
+ }
+ s.result.mode = RESULT_USED
+ s.result.s = s // tell the result that we own it
+ return s.result, nil
+}
+
// Store result
func (s *Statement) StoreResult() (err os.Error) {
// Auto reconnect
@@ -694,32 +715,32 @@ func (s *Statement) getResult(types packetType) (eof bool, err os.Error) {
// Log read result
s.c.log(1, "Reading result packet from server")
// Get result packet
- p, err := s.c.r.readPacket(types)
+ pr, err := s.c.r.readPacket(types)
if err != nil {
return
}
// Process result packet
- switch p.(type) {
+ switch p := pr.(type) {
default:
err = &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR}
case *packetOK:
- err = handleOK(p.(*packetOK), s.c, &s.AffectedRows, &s.LastInsertId, &s.Warnings)
+ err = handleOK(p, s.c, &s.AffectedRows, &s.LastInsertId, &s.Warnings)
case *packetError:
- err = handleError(p.(*packetError), s.c)
+ err = handleError(p, s.c)
case *packetEOF:
eof = true
- err = handleEOF(p.(*packetEOF), s.c)
+ err = handleEOF(p, s.c)
case *packetPrepareOK:
- err = handlePrepareOK(p.(*packetPrepareOK), s.c, s)
+ err = handlePrepareOK(p, s.c, s)
case *packetParameter:
- err = handleParam(p.(*packetParameter), s.c)
+ err = handleParam(p, s.c)
case *packetField:
- err = handleField(p.(*packetField), s.c, s.result)
+ err = handleField(p, s.c, s.result)
case *packetResultSet:
s.result = &Result{c: s.c}
- err = handleResultSet(p.(*packetResultSet), s.c, s.result)
+ err = handleResultSet(p, s.c, s.result)
case *packetRowBinary:
- err = handleBinaryRow(p.(*packetRowBinary), s.c, s.result)
+ err = handleBinaryRow(p, s.c, s.result)
}
return
}
Something went wrong with that request. Please try again.